import os
import re
from typing import Callable, Optional
import torch
from safetensors.torch import load_file
from tqdm import tqdm

import logging

from utils.safetensors_utils import MemoryEfficientSafeOpen

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

from modules.fp8_optimization_utils import optimize_state_dict_with_fp8_on_the_fly


def merge_lora_to_state_dict(
    model_file: str,
    lora_files: Optional[list[str]],
    multipliers: Optional[list[float]],
    fp8_optimization: bool,
    device: torch.device,
    move_to_device: bool = False,
) -> dict[str, torch.Tensor]:
    """
    Merge LoRA weights into the state dict of a model.
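
    Example (a minimal sketch; file paths and multiplier are hypothetical):
        state_dict = merge_lora_to_state_dict(
            model_file="model.safetensors",
            lora_files=["my_lora.safetensors"],
            multipliers=[0.8],
            fp8_optimization=False,
            device=torch.device("cuda"),
        )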
    """

    # if the file name ends with 00001-of-00004 etc., load all shards that share the same prefix
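    # e.g. a (hypothetical) "model-00001-of-00004.safetensors" yields prefix "model-" and 4 shard files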
    basename = os.path.basename(model_file)
    match = re.match(r"^(.*?)(\d+)-of-(\d+)\.safetensors$", basename)
    if match:
        prefix = basename[: match.start(2)]
        count = int(match.group(3))
        model_files = [os.path.normpath(model_file)]
        for i in range(count):
            file_name = f"{prefix}{i+1:05d}-of-{count:05d}.safetensors"
            file_path = os.path.join(os.path.dirname(model_file), file_name)
            file_path = os.path.normpath(file_path)
            if os.path.exists(file_path) and file_path not in model_files:
                model_files.append(file_path)
        logger.info(f"Loading split weights: {model_files}")
    else:
        model_files = [os.path.normpath(model_file)]

    list_of_lora_sd = []
    if lora_files is not None:
        for lora_file in lora_files:
            # Load LoRA safetensors file
            lora_sd = load_file(lora_file)

            # Check the format of the LoRA file
            keys = list(lora_sd.keys())
            if keys[0].startswith("lora_unet_"):
                logger.info("Musubi Tuner LoRA detected")

            else:
                transformer_prefixes = ["diffusion_model", "transformer"]  # to ignore Text Encoder modules
                lora_suffix = None
                prefix = None
                for key in keys:
                    if lora_suffix is None and "lora_A" in key:
                        lora_suffix = "lora_A"
                    if prefix is None:
                        pfx = key.split(".")[0]
                        if pfx in transformer_prefixes:
                            prefix = pfx
                    if lora_suffix is not None and prefix is not None:
                        break

                if lora_suffix == "lora_A" and prefix is not None:
                    logger.info("Diffusion-pipe (?) LoRA detected")
                    lora_sd = convert_from_diffusion_pipe_or_something(lora_sd, "lora_unet_")

                else:
                    logger.warning(f"LoRA file format not recognized: {os.path.basename(lora_file)}")
                    lora_sd = None

            if lora_sd is not None:
                # Check LoRA is for FramePack or for HunyuanVideo
                is_hunyuan = False
                for key in lora_sd.keys():
                    if "double_blocks" in key or "single_blocks" in key:
                        is_hunyuan = True
                        break
                if is_hunyuan:
                    logger.info("HunyuanVideo LoRA detected, converting to FramePack format")
                    lora_sd = convert_hunyuan_to_framepack(lora_sd)

            if lora_sd is not None:
                list_of_lora_sd.append(lora_sd)

    if len(list_of_lora_sd) == 0:
        # no LoRA files found, just load the model
        return load_safetensors_with_fp8_optimization(model_files, fp8_optimization, device, move_to_device, weight_hook=None)

    return load_safetensors_with_lora_and_fp8(model_files, list_of_lora_sd, multipliers, fp8_optimization, device, move_to_device)


def convert_from_diffusion_pipe_or_something(lora_sd: dict[str, torch.Tensor], prefix: str) -> dict[str, torch.Tensor]:
    """
    Convert LoRA weights from the Diffusion-pipe / diffusers-style format to the Musubi Tuner format.
    Copied from the Musubi Tuner repo.
    """
    # convert from diffusers(?) to default LoRA
    # Diffusers format: {"diffusion_model.module.name.lora_A.weight": weight, "diffusion_model.module.name.lora_B.weight": weight, ...}
    # default LoRA format: {"prefix_module_name.lora_down.weight": weight, "prefix_module_name.lora_up.weight": weight, ...}
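    # e.g. a hypothetical key "diffusion_model.blocks.0.attn.to_q.lora_A.weight"
    #      becomes "lora_unet_blocks_0_attn_to_q.lora_down.weight"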

    # note: Diffusers has no alpha, so alpha is set to rank
    new_weights_sd = {}
    lora_dims = {}
    for key, weight in lora_sd.items():
        diffusers_prefix, key_body = key.split(".", 1)
        if diffusers_prefix != "diffusion_model" and diffusers_prefix != "transformer":
            logger.warning(f"unexpected key: {key} in diffusers format")
            continue

        new_key = f"{prefix}{key_body}".replace(".", "_").replace("_lora_A_", ".lora_down.").replace("_lora_B_", ".lora_up.")
        new_weights_sd[new_key] = weight

        lora_name = new_key.split(".")[0]  # before first dot
        if lora_name not in lora_dims and "lora_down" in new_key:
            lora_dims[lora_name] = weight.shape[0]

    # add alpha with rank
    for lora_name, dim in lora_dims.items():
        new_weights_sd[f"{lora_name}.alpha"] = torch.tensor(dim)

    return new_weights_sd


def load_safetensors_with_lora_and_fp8(
    model_files: list[str],
    list_of_lora_sd: list[dict[str, torch.Tensor]],
    multipliers: Optional[list[float]],
    fp8_optimization: bool,
    device: torch.device,
    move_to_device: bool = False,
) -> dict[str, torch.Tensor]:
    """
    Merge LoRA weights into the state dict of a model with fp8 optimization if needed.
    """
    if multipliers is None:
        multipliers = [1.0] * len(list_of_lora_sd)
    if len(multipliers) > len(list_of_lora_sd):
        multipliers = multipliers[: len(list_of_lora_sd)]
    if len(multipliers) < len(list_of_lora_sd):
        multipliers += [1.0] * (len(list_of_lora_sd) - len(multipliers))
    multipliers = [float(m) for m in multipliers]

    list_of_lora_weight_keys = []
    for lora_sd in list_of_lora_sd:
        lora_weight_keys = set(lora_sd.keys())
        list_of_lora_weight_keys.append(lora_weight_keys)

    # Merge LoRA weights into the state dict
    logger.info(f"Merging LoRA weights into the state dict on the fly. multipliers: {multipliers}")

    # make hook for LoRA merging
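    # the hook is called for every tensor as it is streamed from the safetensors file(s);
    # it receives (key, tensor) and must return the (possibly modified) tensor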
    def weight_hook(model_weight_key, model_weight):
        nonlocal list_of_lora_weight_keys, list_of_lora_sd, multipliers

        if not model_weight_key.endswith(".weight"):
            return model_weight

        original_device = model_weight.device
        if original_device != device:
            model_weight = model_weight.to(device)  # to make calculation faster

        for lora_weight_keys, lora_sd, multiplier in zip(list_of_lora_weight_keys, list_of_lora_sd, multipliers):
            # check if this weight has LoRA weights
            lora_name = model_weight_key.rsplit(".", 1)[0]  # remove trailing ".weight"
            lora_name = "lora_unet_" + lora_name.replace(".", "_")
            down_key = lora_name + ".lora_down.weight"
            up_key = lora_name + ".lora_up.weight"
            alpha_key = lora_name + ".alpha"
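            # e.g. a hypothetical model key "transformer_blocks.0.attn.to_q.weight" looks up
            # "lora_unet_transformer_blocks_0_attn_to_q.lora_down.weight" / ".lora_up.weight" / ".alpha"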
            if down_key not in lora_weight_keys or up_key not in lora_weight_keys:
                continue  # this LoRA has no weights for this module; check the remaining LoRAs

            # get LoRA weights
            down_weight = lora_sd[down_key]
            up_weight = lora_sd[up_key]

            dim = down_weight.size()[0]
            alpha = lora_sd.get(alpha_key, dim)
            scale = alpha / dim
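            # alpha defaults to the rank, so scale is 1.0 unless the LoRA file stores a different alpha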

            down_weight = down_weight.to(device)
            up_weight = up_weight.to(device)

            # W <- W + multiplier * (U @ D) * scale
            if len(model_weight.size()) == 2:
                # linear
                if len(up_weight.size()) == 4:  # 4D (out, in, 1, 1) weights applied to a linear layer; squeeze to 2D
                    up_weight = up_weight.squeeze(3).squeeze(2)
                    down_weight = down_weight.squeeze(3).squeeze(2)
                model_weight = model_weight + multiplier * (up_weight @ down_weight) * scale
            elif down_weight.size()[2:4] == (1, 1):
                # conv2d 1x1
                model_weight = (
                    model_weight
                    + multiplier
                    * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
                    * scale
                )
            else:
                # conv2d 3x3
                conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
                model_weight = model_weight + multiplier * conved * scale

            # remove LoRA keys from set
            lora_weight_keys.remove(down_key)
            lora_weight_keys.remove(up_key)
            if alpha_key in lora_weight_keys:
                lora_weight_keys.remove(alpha_key)

        model_weight = model_weight.to(original_device)  # move back to original device
        return model_weight

    state_dict = load_safetensors_with_fp8_optimization(
        model_files, fp8_optimization, device, move_to_device, weight_hook=weight_hook
    )

    for lora_weight_keys in list_of_lora_weight_keys:
        if len(lora_weight_keys) > 0:
            # if there are still LoRA keys left, it means they are not used in the model
            # this is a warning, not an error
            logger.warning(f"{len(lora_weight_keys)} LoRA keys were not used in the model: {lora_weight_keys}")

    return state_dict


def convert_hunyuan_to_framepack(lora_sd: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
    """
    Convert HunyuanVideo LoRA weights to FramePack format.
    """
    new_lora_sd = {}
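    # e.g. a hypothetical key "lora_unet_double_blocks_0_img_attn_qkv.lora_up.weight" is renamed to
    # "..._attn_to_QKV.lora_up.weight" and then split row-wise into separate q / k / v up-weights below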
    for key, weight in lora_sd.items():
        if "double_blocks" in key:
            key = key.replace("double_blocks", "transformer_blocks")
            key = key.replace("img_mod_linear", "norm1_linear")
            key = key.replace("img_attn_qkv", "attn_to_QKV")  # split later
            key = key.replace("img_attn_proj", "attn_to_out_0")
            key = key.replace("img_mlp_fc1", "ff_net_0_proj")
            key = key.replace("img_mlp_fc2", "ff_net_2")
            key = key.replace("txt_mod_linear", "norm1_context_linear")
            key = key.replace("txt_attn_qkv", "attn_add_QKV_proj")  # split later
            key = key.replace("txt_attn_proj", "attn_to_add_out")
            key = key.replace("txt_mlp_fc1", "ff_context_net_0_proj")
            key = key.replace("txt_mlp_fc2", "ff_context_net_2")
        elif "single_blocks" in key:
            key = key.replace("single_blocks", "single_transformer_blocks")
            key = key.replace("linear1", "attn_to_QKVM")  # split later
            key = key.replace("linear2", "proj_out")
            key = key.replace("modulation_linear", "norm_linear")
        else:
            logger.warning(f"Unsupported module name: {key}, only double_blocks and single_blocks are supported")
            continue

        if "QKVM" in key:
            # split QKVM into Q, K, V, M
            key_q = key.replace("QKVM", "q")
            key_k = key.replace("QKVM", "k")
            key_v = key.replace("QKVM", "v")
            key_m = key.replace("attn_to_QKVM", "proj_mlp")
            if "_down" in key or "alpha" in key:
                # copy QKVM weight or alpha to Q, K, V, M
                assert "alpha" in key or weight.size(1) == 3072, f"QKVM weight size mismatch: {key}. {weight.size()}"
                new_lora_sd[key_q] = weight
                new_lora_sd[key_k] = weight
                new_lora_sd[key_v] = weight
                new_lora_sd[key_m] = weight
            elif "_up" in key:
                # split QKVM weight into Q, K, V, M
                assert weight.size(0) == 21504, f"QKVM weight size mismatch: {key}. {weight.size()}"
                new_lora_sd[key_q] = weight[:3072]
                new_lora_sd[key_k] = weight[3072 : 3072 * 2]
                new_lora_sd[key_v] = weight[3072 * 2 : 3072 * 3]
                new_lora_sd[key_m] = weight[3072 * 3 :]  # 21504 - 3072 * 3 = 12288
            else:
                logger.warning(f"Unsupported module name: {key}")
                continue
        elif "QKV" in key:
            # split QKV into Q, K, V
            key_q = key.replace("QKV", "q")
            key_k = key.replace("QKV", "k")
            key_v = key.replace("QKV", "v")
            if "_down" in key or "alpha" in key:
                # copy QKV weight or alpha to Q, K, V
                assert "alpha" in key or weight.size(1) == 3072, f"QKV weight size mismatch: {key}. {weight.size()}"
                new_lora_sd[key_q] = weight
                new_lora_sd[key_k] = weight
                new_lora_sd[key_v] = weight
            elif "_up" in key:
                # split QKV weight into Q, K, V
                assert weight.size(0) == 3072 * 3, f"QKV weight size mismatch: {key}. {weight.size()}"
                new_lora_sd[key_q] = weight[:3072]
                new_lora_sd[key_k] = weight[3072 : 3072 * 2]
                new_lora_sd[key_v] = weight[3072 * 2 :]
            else:
                logger.warning(f"Unsupported module name: {key}")
                continue
        else:
            # no split needed
            new_lora_sd[key] = weight

    return new_lora_sd


def load_safetensors_with_fp8_optimization(
    model_files: list[str], fp8_optimization: bool, device: torch.device, move_to_device: bool, weight_hook: Optional[Callable] = None
) -> dict[str, torch.Tensor]:
    """
    Load a state dict from safetensors files, applying an optional per-weight hook (e.g. LoRA merging) and FP8 optimization if requested.
    """
    if fp8_optimization:
        TARGET_KEYS = ["transformer_blocks", "single_transformer_blocks"]
        EXCLUDE_KEYS = ["norm"]  # Exclude norm layers (e.g., LayerNorm, RMSNorm) from FP8
        state_dict = optimize_state_dict_with_fp8_on_the_fly(
            model_files, device, TARGET_KEYS, EXCLUDE_KEYS, move_to_device=move_to_device, weight_hook=weight_hook
        )
    else:
        state_dict = {}
        for model_file in model_files:
            with MemoryEfficientSafeOpen(model_file) as f:
                for key in tqdm(f.keys(), desc=f"Loading {model_file}", leave=False):
                    value = f.get_tensor(key)
                    if weight_hook is not None:
                        value = weight_hook(key, value)
                    if move_to_device:
                        value = value.to(device)
                    state_dict[key] = value

    return state_dict