# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import re
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Tuple

import numpy as np
import torch
from torch.distributed._tensor import DTensor, Placement, Shard
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForTokenClassification,
    AutoModelForVision2Seq,
    PretrainedConfig,
    PreTrainedModel,
)


def merge_by_placement(tensors: List[torch.Tensor], placement: Placement) -> torch.Tensor:
    """Merge the per-rank shards of one parameter according to its DTensor placement."""
    if placement.is_replicate():
        # Replicated tensors are identical on every rank; any copy will do.
        return tensors[0]
    elif placement.is_partial():
        raise NotImplementedError("Partial placement is not supported yet")
    elif placement.is_shard():
        # Sharded tensors are concatenated back along their sharding dimension.
        return torch.cat(tensors, dim=placement.dim).contiguous()
    else:
        raise ValueError(f"Unsupported placement: {placement}")


def upload_model_to_huggingface(local_path: str, remote_path: str):
    # Push the merged model to the Hugging Face Hub.
    from huggingface_hub import HfApi

    api = HfApi()
    api.create_repo(repo_id=remote_path, private=False, exist_ok=True)
    api.upload_folder(repo_id=remote_path, folder_path=local_path, repo_type="model")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_dir", required=True, type=str, help="The path of your saved model checkpoints")
    parser.add_argument("--hf_upload_path", default=None, type=str, help="The Hugging Face repo id to upload the merged model to")
    args = parser.parse_args()
    local_dir: str = args.local_dir
    assert not local_dir.endswith("huggingface"), "The local_dir should not end with huggingface."

    # Load the rank-0 shard first to discover the (dp, fsdp) device-mesh shape.
    rank = 0
    world_size = 0
    for filename in os.listdir(local_dir):
        match = re.match(r"model_world_size_(\d+)_rank_0\.pt", filename)
        if match:
            world_size = int(match.group(1))
            break
    assert world_size, "No model file with the expected name format was found."

    rank0_weight_path = os.path.join(local_dir, f"model_world_size_{world_size}_rank_{rank}.pt")
    state_dict = torch.load(rank0_weight_path, map_location="cpu", weights_only=False)
    pivot_key = sorted(state_dict.keys())[0]
    weight = state_dict[pivot_key]
    if isinstance(weight, DTensor):
        # Recover the sharding info from the DTensor's device mesh.
        device_mesh = weight.device_mesh
        mesh = device_mesh.mesh
        mesh_dim_names = device_mesh.mesh_dim_names
    else:
        # Plain tensors carry no mesh info; assume a 1-D FSDP mesh over all ranks.
        mesh = np.array([world_size], dtype=np.int64)
        mesh_dim_names = ("fsdp",)

    print(f"Got device mesh {mesh}, mesh_dim_names {mesh_dim_names}")
    assert mesh_dim_names in (("fsdp",), ("ddp", "fsdp")), f"Unsupported mesh_dim_names {mesh_dim_names}."
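
    # Note on the branch below: the checkpoints are per-rank shards named
    # `model_world_size_{N}_rank_{r}.pt`, sharded over a 1-D FSDP mesh or a 2-D
    # (ddp, fsdp) mesh. The `tp` branch is kept for a future 2-D (fsdp, tp) mesh,
    # but the assert above currently rules it out, so only the 1-D path runs.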
    if "tp" in mesh_dim_names:
        # fsdp * tp: the checkpoint is sharded over a 2-D mesh.
        total_shards = mesh.shape[-1] * mesh.shape[-2]
        mesh_shape = (mesh.shape[-2], mesh.shape[-1])
    else:
        # fsdp only: the checkpoint is sharded over a 1-D mesh.
        total_shards = mesh.shape[-1]
        mesh_shape = (mesh.shape[-1],)

    print(f"Processing {total_shards} model shards in total.")

    # Slot 0 already holds the rank-0 state dict; the remaining shards are loaded
    # concurrently below.
    model_state_dict_lst = [None] * total_shards
    model_state_dict_lst[0] = state_dict

    def process_one_shard(rank: int, model_state_dict_lst: list):
        model_path = os.path.join(local_dir, f"model_world_size_{world_size}_rank_{rank}.pt")
        state_dict = torch.load(model_path, map_location="cpu", weights_only=False)
        model_state_dict_lst[rank] = state_dict
        return state_dict

    with ThreadPoolExecutor(max_workers=min(32, os.cpu_count())) as executor:
        futures = [executor.submit(process_one_shard, rank, model_state_dict_lst) for rank in range(1, total_shards)]
        for future in futures:
            future.result()  # surface any loading errors instead of silently dropping them

    # Regroup the per-rank state dicts into per-key lists of shards.
    state_dict: Dict[str, List[torch.Tensor]] = {}
    param_placements: Dict[str, List[Placement]] = {}
    keys = set(model_state_dict_lst[0].keys())
    for key in keys:
        state_dict[key] = []
        for shard_rank, model_state_dict in enumerate(model_state_dict_lst):
            try:
                tensor = model_state_dict.pop(key)
            except KeyError:
                print(f"Cannot find key {key} in rank {shard_rank}.")
                continue
            if isinstance(tensor, DTensor):
                state_dict[key].append(tensor._local_tensor.bfloat16())
                placements = tuple(tensor.placements)
                # The replicated placement at the ddp dimension can be discarded.
                if mesh_dim_names[0] == "ddp":
                    placements = placements[1:]
                if key not in param_placements:
                    param_placements[key] = placements
                else:
                    assert param_placements[key] == placements
            else:
                state_dict[key].append(tensor.bfloat16())

    del model_state_dict_lst

    # Merge the shards of every parameter back into a single full tensor.
    for key in sorted(state_dict):
        if not isinstance(state_dict[key], list):
            print(f"No need to merge key {key}")
            continue
        if key in param_placements:
            # Merge DTensor shards according to their recorded placements.
            placements: Tuple[Shard] = param_placements[key]
            if len(mesh_shape) == 1:
                # 1-D mesh: FSDP without TP.
                assert len(placements) == 1
                shards = state_dict[key]
                state_dict[key] = merge_by_placement(shards, placements[0])
            else:
                # 2-D mesh: FSDP + TP.
                raise NotImplementedError("FSDP + TP is not supported yet.")
        else:
            # Plain tensors are assumed to be FSDP-sharded along dim 0.
            state_dict[key] = torch.cat(state_dict[key], dim=0)

    print("Merge completed.")

    hf_path = os.path.join(local_dir, "huggingface")
    config: PretrainedConfig = AutoConfig.from_pretrained(hf_path)
    architectures: List[str] = getattr(config, "architectures", ["Unknown"])
    if "ForTokenClassification" in architectures[0]:
        AutoClass = AutoModelForTokenClassification
    elif "ForCausalLM" in architectures[0]:
        AutoClass = AutoModelForCausalLM
    elif "ForConditionalGeneration" in architectures[0]:
        AutoClass = AutoModelForVision2Seq
    else:
        raise NotImplementedError(f"Unknown architecture {architectures}.")

    # Instantiate the model on the meta device so no real weights are allocated,
    # then materialize empty CPU storage and save it with the merged state dict.
    with torch.device("meta"):
        model: PreTrainedModel = AutoClass.from_config(config, torch_dtype=torch.bfloat16)
    assert isinstance(model, PreTrainedModel)
    model.to_empty(device="cpu")

    print(f"Saving model to {hf_path}...")
    model.save_pretrained(hf_path, state_dict=state_dict)
    del state_dict, model
    if args.hf_upload_path:
        upload_model_to_huggingface(hf_path, args.hf_upload_path)
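
# Example invocation (hypothetical script and checkpoint names; the checkpoint
# directory must contain the `model_world_size_*_rank_*.pt` shards and a
# `huggingface/` subfolder with the model config):
#   python model_merger.py --local_dir checkpoints/my_run/global_step_100/actor
# Optionally push the merged model to the Hub (hypothetical repo id):
#   python model_merger.py --local_dir checkpoints/my_run/global_step_100/actor \
#       --hf_upload_path my-org/my-merged-model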