# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.

"""Estimate theoretical GPU memory usage for GPT/MoE models (adapted from Megatron-LM's pretrain_gpt.py)."""

import warnings

warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore")

import inspect
import os
from contextlib import nullcontext
from functools import partial
from typing import List, Optional, Tuple, Union

import torch

from megatron.core import mpu
from megatron.core.datasets.blended_megatron_dataset_builder import (
    BlendedMegatronDatasetBuilder,
)
from megatron.core.datasets.gpt_dataset import (
    GPTDataset,
    GPTDatasetConfig,
    MockGPTDataset,
)
from megatron.core.datasets.utils import get_blend_from_list
from megatron.core.enums import ModelType
from megatron.core.models.gpt.gpt_layer_specs import (
    get_gpt_decoder_block_spec,
    get_gpt_layer_local_spec,
    get_gpt_layer_with_transformer_engine_spec,
    get_gpt_mtp_block_spec,
)
from megatron.core.transformer.spec_utils import import_module
from megatron.core.utils import StragglerDetector
from megatron.training import (
    get_args,
    get_timers,
    get_tokenizer,
    pretrain,
    print_rank_0,
)
from megatron.training.arguments import core_transformer_config_from_args
from megatron.training.initialize import initialize_megatron
from megatron.training.utils import get_batch_on_this_cp_rank, get_batch_on_this_tp_rank
from megatron.training.yaml_arguments import core_transformer_config_from_yaml

from moe_mem_estimator.base import (
    get_pipeline_model_parallel_rank,
    get_pipeline_model_parallel_world_size,
    get_virtual_pipeline_model_parallel_world_size,
    is_pipeline_first_stage,
    is_pipeline_last_stage,
    set_global_config,
    set_pipeline_model_parallel_rank,
)
from moe_mem_estimator.gpt_model import GPTModel
from moe_mem_estimator.layers import MLASelfAttention, MoELayer

# Stub out distributed/CUDA queries so the estimator can run in a single
# CPU-only process without initializing torch.distributed.
torch.distributed.get_rank = lambda: 0
torch.cuda.get_device_capability = lambda: [8]


def estimate_from_config(config, args):
    """
    Estimate memory usage from a given config and args instead of global state.
    Supports virtual pipeline model parallelism for more accurate results.
    """
    args.moe_grouped_gemm = True
    patch_parallel_states()
    if config is None:
        if args.yaml_cfg is not None:
            config = core_transformer_config_from_yaml(args, "language_model")
        else:
            config = core_transformer_config_from_args(args)
    input_shape = [args.micro_batch_size, args.seq_length]
    set_global_config(config)
    print(config)

    cli_reports = []
    if config.pipeline_model_parallel_size > 1:
        for pp_rank in range(config.pipeline_model_parallel_size):
            set_pipeline_model_parallel_rank(pp_rank)
            print(
                f"\n------------------------------[Pipeline_Parallelism_Rank={pp_rank}]------------------------------"
            )
            input_shape, rpt = report_memory_usage_one_pp_rank(
                input_shape, args, config, pp_rank, config.pipeline_model_parallel_size
            )
            cli_reports.append(rpt)
    else:
        set_pipeline_model_parallel_rank(0)
        _, rpt = report_memory_usage_one_pp_rank(input_shape, args, config)
        cli_reports.append(rpt)
    aggregated_reports: list[dict] = cli_reports
    # Return (aggregated per-PP-rank report list, full list of raw chunk reports).
    return aggregated_reports, cli_reports
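

# A minimal usage sketch for estimate_from_config (assumes Megatron arguments
# have already been parsed, e.g. via initialize_megatron()/get_args() as in the
# __main__ block below):
#
#     args = get_args()
#     reports, raw_reports = estimate_from_config(None, args)
#     for r in reports:
#         print(r["pp_rank"], r["total_gb"], r["activation_gb"])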


def _get_transformer_layer_spec(use_te, config):
    """Get the transformer layer specification based on the configuration.

    Args:
        use_te (bool): Whether to use Transformer Engine.
        config: Model configuration.

    Returns:
        transformer_layer_spec: The transformer layer specification.
    """
    if use_te:
        return get_gpt_layer_with_transformer_engine_spec(
            config.num_moe_experts,
            config.moe_grouped_gemm,
            config.qk_layernorm,
            config.multi_latent_attention,
            config.fp8,
        )
    else:
        return get_gpt_layer_local_spec(
            config.num_moe_experts,
            config.moe_grouped_gemm,
            config.qk_layernorm,
            config.multi_latent_attention,
        )


def model_provider(
    args, config, pre_process=True, post_process=True, vp_stage: Optional[int] = None
) -> GPTModel:
    """Build an estimator GPTModel for the given (virtual) pipeline stage."""
    use_te = True

    if args.num_experts:
        # Define the decoder block spec.
        transformer_layer_spec = get_gpt_decoder_block_spec(
            config,
            use_transformer_engine=use_te,
            normalization="LayerNorm",
            qk_l2_norm=False,
            vp_stage=vp_stage,
        )
    else:
        # Define the decoder layer spec.
        transformer_layer_spec = _get_transformer_layer_spec(use_te, config)

    mtp_block_spec = None  # TODO: MTP block spec / fp8 are not handled yet.

    model = GPTModel(
        config=config,
        transformer_layer_spec=transformer_layer_spec,
        vocab_size=args.padded_vocab_size,
        max_sequence_length=args.max_position_embeddings,
        pre_process=pre_process,
        post_process=post_process,
        fp16_lm_cross_entropy=getattr(config, "fp16_lm_cross_entropy", False),
        parallel_output=True,
        share_embeddings_and_output_weights=False,
        position_embedding_type="rope",
        rotary_percent=getattr(args, "rotary_percent", 1.0),
        rotary_base=getattr(args, "rotary_base", 10000),
        rope_scaling=getattr(config, "use_rope_scaling", False),
        mtp_block_spec=mtp_block_spec,
        vp_stage=vp_stage,
    )
    return model
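

# Note on interleaved (virtual) pipeline parallelism, as used by get_model
# below: with the interleaved schedule each pipeline rank owns
# virtual_pipeline_model_parallel_size model chunks, so get_model returns a
# list of chunks per rank (e.g. PP=4 with VPP=2 yields 2 chunks on every rank).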


def get_model(
    model_provider_func, args, config, model_type=ModelType.encoder_or_decoder
):
    """Build the model."""
    # Build model.
    if not getattr(args, "virtual_pipeline_model_parallel_size", None):
        args.virtual_pipeline_model_parallel_size = None
    if config.pipeline_model_parallel_layout:
        args.virtual_pipeline_model_parallel_size = (
            config.pipeline_model_parallel_layout.virtual_pipeline_model_parallel_size
        )
        config.virtual_pipeline_model_parallel_size = (
            config.pipeline_model_parallel_layout.virtual_pipeline_model_parallel_size
        )

    def build_model():
        if (
            get_pipeline_model_parallel_world_size() > 1
            and args.virtual_pipeline_model_parallel_size is not None
        ):
            if model_type == ModelType.encoder_and_decoder:
                assert (
                    config.encoder_pipeline_model_parallel_size == 0
                ), "Interleaved schedule not supported for model with encoder on separate PP rank"
            model = []
            for i in range(args.virtual_pipeline_model_parallel_size):
                # Set pre_process and post_process only after virtual rank is set.
                pre_process = is_pipeline_first_stage(ignore_virtual=False, vp_stage=i)
                post_process = is_pipeline_last_stage(ignore_virtual=False, vp_stage=i)
                this_model = model_provider_func(
                    args,
                    config,
                    pre_process=pre_process,
                    post_process=post_process,
                    vp_stage=i,
                )
                this_model.model_type = model_type
                this_model.vp_stage = i
                model.append(this_model)
        else:
            pre_process = is_pipeline_first_stage()
            post_process = is_pipeline_last_stage()
            if model_type == ModelType.encoder_and_decoder:
                if get_pipeline_model_parallel_world_size() > 1:
                    rank = get_pipeline_model_parallel_rank()
                    first_decoder_rank = config.encoder_pipeline_model_parallel_size
                    world_size = get_pipeline_model_parallel_world_size()
                    pre_process = rank == 0 or rank == first_decoder_rank
                    post_process = (rank == (first_decoder_rank - 1)) or (
                        rank == (world_size - 1)
                    )
                model = model_provider_func(
                    args,
                    config,
                    pre_process=pre_process,
                    post_process=post_process,
                )
            else:
                model = model_provider_func(
                    args, config, pre_process=pre_process, post_process=post_process
                )
            model.model_type = model_type
        return model

    model = build_model()

    if not isinstance(model, list):
        model = [model]

    return model


NUM_BYTES_IN_MEGABYTE = 1024 * 1024
NUM_BYTES_IN_GIGABYTE = 1024 * 1024 * 1024


def patch_parallel_states():
    from megatron.core import parallel_state

    parallel_state.is_pipeline_first_stage = is_pipeline_first_stage
    parallel_state.is_pipeline_last_stage = is_pipeline_last_stage
    parallel_state.get_pipeline_model_parallel_rank = get_pipeline_model_parallel_rank
    parallel_state.get_pipeline_model_parallel_world_size = (
        get_pipeline_model_parallel_world_size
    )
    parallel_state.get_virtual_pipeline_model_parallel_world_size = (
        get_virtual_pipeline_model_parallel_world_size
    )
    parallel_state.is_inside_encoder = lambda: False
    parallel_state.get_pipeline_model_parallel_decoder_start = lambda: 0


def report_memory_usage(args, config=None):
    args.moe_grouped_gemm = True
    patch_parallel_states()
    if config is None:
        if args.yaml_cfg is not None:
            config = core_transformer_config_from_yaml(args, "language_model")
        else:
            config = core_transformer_config_from_args(args)
    input_shape = [args.micro_batch_size, args.seq_length]
    set_global_config(config)

    cli_reports = []
    if config.pipeline_model_parallel_size > 1:
        for pp_rank in range(config.pipeline_model_parallel_size):
            set_pipeline_model_parallel_rank(pp_rank)
            print(
                f"\n------------------------------[Pipeline_Parallelism_Rank={pp_rank}]------------------------------"
            )
            input_shape, rpt = report_memory_usage_one_pp_rank(
                input_shape, args, config, pp_rank, config.pipeline_model_parallel_size
            )
            cli_reports.append(rpt)
    else:
        set_pipeline_model_parallel_rank(0)
        _, rpt = report_memory_usage_one_pp_rank(input_shape, args, config)
        cli_reports.append(rpt)

    # Optionally pretty print summary
    print("\n===== Summary (per PP rank) =====")
    for r in cli_reports:
        print(
            f"PP{r['pp_rank']} total {r['total_gb']} GB (weight_grad {r['weight_grad_gb']} GB weight_grad_optim {r['weight_grad_optim_gb']} GB act {r['activation_gb']} GB)"
        )
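

# In-flight microbatch counting used by report_memory_usage_one_pp_rank (a
# sketch of the scheduling assumption): under non-interleaved 1F1B, pipeline
# rank r keeps activations for (pp_size - r) microbatches alive at its peak,
# e.g. with pp_size=4 the first stage holds 4 microbatches while the last
# stage holds 1. The interleaved (VPP) branch below adjusts this per chunk.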


def report_memory_usage_one_pp_rank(
    input_shape: list[int], args, config, pp_rank=0, pp_size=1
) -> tuple[list[int], dict]:
    print(f"{input_shape=}")
    model: list[GPTModel] = get_model(model_provider, args, config)

    num_parameter_this_shard_all = 0
    num_parameter_this_shard_sparse_all = 0
    num_activation_all = 0
    output_shape = input_shape
    for vpp_rank, one_chunk in enumerate(model):
        num_parameter_this_shard = one_chunk.num_parameter()
        num_activation = one_chunk.num_activation(output_shape)
        output_shape = one_chunk.mock_forward(output_shape)
        print(f"{output_shape=}")

        # Count MoE (sparse) parameters separately; shared experts are treated as dense.
        num_parameter_this_shard_sparse = 0
        for layer in one_chunk.decoder.layers.modules:
            if isinstance(layer.mlp, MoELayer):
                num_parameter_this_shard_sparse += layer.mlp.num_parameter()
                if (
                    "shared_experts" in layer.mlp.__dir__()
                    and layer.mlp.shared_experts is not None
                ):
                    num_parameter_this_shard_sparse -= (
                        layer.mlp.shared_experts.num_parameter()
                    )
        num_activation_this_shard_mlp = sum(
            [m.mlp.num_activation() for m in one_chunk.decoder.layers.modules]
        )

        if len(model) > 1:
            # Interleaved schedule: in-flight microbatch count depends on the chunk.
            if vpp_rank >= 1 and vpp_rank < len(model) - 1:
                num_microbatch_this_pp_rank = pp_size
            elif vpp_rank == 0:
                num_microbatch_this_pp_rank = pp_size + max(
                    (pp_size - pp_rank) * 2 - 1 - pp_size, 0
                )
            elif vpp_rank == len(model) - 1:
                num_microbatch_this_pp_rank = min((pp_size - pp_rank) * 2 + 1, pp_size)
        else:
            # Non-interleaved 1F1B.
            num_microbatch_this_pp_rank = pp_size - pp_rank

        one_chunk.__repr__()
        print(one_chunk)
        print(
            f"Number of parameters per GPU in billions: "
            f"{num_parameter_this_shard / 10**9: .2f} where the MLP part is {num_parameter_this_shard_sparse / 10**9: .2f}"
        )
        num_parameter_this_shard_all += num_parameter_this_shard
        num_parameter_this_shard_sparse_all += num_parameter_this_shard_sparse

        # Recompute
        if config.recompute_granularity == "full":
            recompute_num_layers = config.recompute_num_layers
            num_layers = one_chunk.num_layers
            common_act = (
                one_chunk.num_act_pre
                + one_chunk.num_act_between_layers
                * num_layers
                * num_microbatch_this_pp_rank
            )
            # Recompute with pipeline parallelism.
            info = "With this recompute setting, activation memory peaks when "
            if config.recompute_method == "block":
                num_layers_with_loss = num_layers - recompute_num_layers
                if num_layers_with_loss == 0:
                    peak1 = common_act + one_chunk.num_act_post
                    peak2 = common_act + one_chunk.num_act_per_layer
                    if peak1 > peak2:
                        info += "calculating the loss"
                    else:
                        info += "back-propagating the loss"
                    num_activation = max(peak1, peak2)
                else:
                    info += f"calculating the loss with {num_layers_with_loss} non-recomputed layers"
                    num_activation = (
                        common_act
                        + one_chunk.num_act_post
                        + one_chunk.num_act_per_layer
                        * num_layers_with_loss
                        * num_microbatch_this_pp_rank
                    )
            elif config.recompute_method == "uniform":
                peak1 = common_act + one_chunk.num_act_post
                peak2 = (
                    (common_act + one_chunk.num_act_per_layer)
                    if vpp_rank == 0
                    else (common_act)
                )
                if peak1 > peak2:
                    info += "calculating the loss"
                else:
                    info += f"back-propagating the loss, recomputing every {recompute_num_layers} layers"
                num_activation = max(peak1, peak2)
                if len(one_chunk.decoder.layers.modules) > 0 and isinstance(
                    one_chunk.decoder.layers.modules[0].self_attention, MLASelfAttention
                ):
                    # MLA recompute reaches its peak during the backward pass.
                    num_activation += one_chunk.decoder.layers.modules[
                        0
                    ].self_attention.core_attention.num_activation()
            print(info)
        else:
            num_activation = (
                num_activation - one_chunk.num_act_post
            ) * num_microbatch_this_pp_rank + one_chunk.num_act_post

        # Context parallelism splits activations across CP ranks.
        num_activation = num_activation / config.context_parallel_size

        if pp_size == 1:
            print(
                f"Number of activations per GPU in billions: "
                f"{num_activation / 10**9: .2f} where the MLP part is {num_activation_this_shard_mlp / 10**9: .2f}"
            )
        else:
            print(
                f"Number of activations per microbatch per GPU in billions: "
                f"{num_activation / 10**9: .2f} where the MLP part is {num_activation_this_shard_mlp / 10**9: .2f}"
                f", {num_microbatch_this_pp_rank=} {vpp_rank=}"
            )
        num_activation_all += num_activation
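
    # Bytes-per-parameter accounting below (a sketch of the assumption, mirroring
    # Megatron-LM's theoretical memory report for mixed-precision Adam):
    #   18 bytes/param = 2 (fp16/bf16 weight) + 4 (fp32 grad)
    #                  + 12 (fp32 master weight + Adam momentum + Adam variance).
    # With the distributed optimizer, the 12 optimizer-state bytes are sharded
    # across the data-parallel replicas (here DP x CP for dense parameters, and
    # the non-expert-parallel dimensions for MoE parameters), hence 6 + 12 / N.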
    num_bytes_per_parameter = (
        18
        if not args.use_distributed_optimizer
        else 6 + (12 / args.data_parallel_size / config.context_parallel_size)
    )
    if config.expert_model_parallel_size * config.expert_tensor_parallel_size > 1:
        num_bytes_per_parameter_dense = num_bytes_per_parameter
        num_bytes_per_parameter_moe = (
            18
            if not args.use_distributed_optimizer
            else 6
            + (
                12
                / (
                    args.world_size
                    / config.pipeline_model_parallel_size
                    / config.expert_model_parallel_size
                    / config.expert_tensor_parallel_size
                )
            )
        )
        print(f"{num_bytes_per_parameter_dense=} {num_bytes_per_parameter_moe=}")
        weight_grad_memory = num_parameter_this_shard_all * 6 / NUM_BYTES_IN_GIGABYTE
        weight_grad_optim_memory = (
            (num_parameter_this_shard_all - num_parameter_this_shard_sparse_all)
            * num_bytes_per_parameter_dense
            + num_parameter_this_shard_sparse_all * num_bytes_per_parameter_moe
        ) / NUM_BYTES_IN_GIGABYTE
    else:
        print(f"{num_bytes_per_parameter=}")
        weight_grad_memory = num_parameter_this_shard_all * 6 / NUM_BYTES_IN_GIGABYTE
        weight_grad_optim_memory = (
            num_parameter_this_shard_all
            * num_bytes_per_parameter
            / NUM_BYTES_IN_GIGABYTE
        )

    # Activations are assumed to be fp16/bf16 (2 bytes each).
    activation_memory = num_activation_all * 2 / NUM_BYTES_IN_GIGABYTE
    total_memory = weight_grad_optim_memory + activation_memory

    print(
        f"Theoretical memory footprints: weight and optimizer={weight_grad_optim_memory:.2f} GB, "
        f"activation={activation_memory:.2f} GB, total={total_memory:.2f} GB\n"
    )

    # Build an aggregated report in the same format as estimate_from_config.
    model_breakdown_concat = "\n\n".join(
        [f"--- vpp_chunk {i} ---\n{str(m)}" for i, m in enumerate(model)]
    )
    report = {
        "pp_rank": pp_rank,
        "parameters_b": num_parameter_this_shard_all / 1e9,
        "activation_b": num_activation_all / 1e9,
        "weight_grad_gb": round(weight_grad_memory, 2),
        "weight_grad_optim_gb": round(weight_grad_optim_memory, 2),
        "activation_gb": round(activation_memory, 2),
        "total_gb": round(total_memory, 2),
        "model_breakdown": model_breakdown_concat,
        "details": None,
    }
    return output_shape, report


if __name__ == "__main__":
    initialize_megatron(allow_no_cuda=True, skip_mpu_initialization=True)
    import ipdb

    with ipdb.launch_ipdb_on_exception():
        args = get_args()
        report_memory_usage(args)
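

# Example invocation (a sketch; the script name is a placeholder and the flag
# values are illustrative, since the script expects a full Megatron-LM argument
# list, including tokenizer settings):
#
#   python <this_script>.py \
#       --num-layers 32 --hidden-size 4096 --num-attention-heads 32 \
#       --seq-length 4096 --max-position-embeddings 4096 \
#       --micro-batch-size 1 --tensor-model-parallel-size 1 \
#       --pipeline-model-parallel-size 4 --expert-model-parallel-size 8 \
#       --num-experts 64 --use-distributed-optimizer --bf16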