File size: 19,658 Bytes
d2006b6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 |
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
"""Pretrain GPT."""
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore")
import inspect
import os
from contextlib import nullcontext
from functools import partial
from typing import List, Optional, Tuple, Union
import torch
from megatron.core import mpu
from megatron.core.datasets.blended_megatron_dataset_builder import (
BlendedMegatronDatasetBuilder,
)
from megatron.core.datasets.gpt_dataset import (
GPTDataset,
GPTDatasetConfig,
MockGPTDataset,
)
from megatron.core.datasets.utils import get_blend_from_list
from megatron.core.enums import ModelType
from megatron.core.models.gpt.gpt_layer_specs import (
get_gpt_decoder_block_spec,
get_gpt_layer_local_spec,
get_gpt_layer_with_transformer_engine_spec,
get_gpt_mtp_block_spec,
)
from megatron.core.transformer.spec_utils import import_module
from megatron.core.utils import StragglerDetector
from megatron.training import (
get_args,
get_timers,
get_tokenizer,
pretrain,
print_rank_0,
)
from megatron.training.arguments import core_transformer_config_from_args
from megatron.training.initialize import initialize_megatron
from megatron.training.utils import get_batch_on_this_cp_rank, get_batch_on_this_tp_rank
from megatron.training.yaml_arguments import core_transformer_config_from_yaml
from moe_mem_estimator.base import (
get_pipeline_model_parallel_rank,
get_pipeline_model_parallel_world_size,
get_virtual_pipeline_model_parallel_world_size,
is_pipeline_first_stage,
is_pipeline_last_stage,
set_global_config,
set_pipeline_model_parallel_rank,
)
from moe_mem_estimator.gpt_model import GPTModel
from moe_mem_estimator.layers import MLASelfAttention, MoELayer
torch.distributed.get_rank = lambda: 0
torch.cuda.get_device_capability = lambda: [8]
def estimate_from_config(config, args):
"""
Estimate memory usage from a given config and args, instead of global state.
Now supports virtual pipeline model parallelism for more accurate results.
"""
args.moe_grouped_gemm = True
patch_parallel_states()
if config is None:
if args.yaml_cfg is not None:
config = core_transformer_config_from_yaml(args, "language_model")
else:
config = core_transformer_config_from_args(args)
input_shape = [args.micro_batch_size, args.seq_length]
set_global_config(config)
print(config)
# return
cli_reports = []
if config.pipeline_model_parallel_size > 1:
for pp_rank in range(config.pipeline_model_parallel_size):
set_pipeline_model_parallel_rank(pp_rank)
print(
f"\n------------------------------[Pipeline_Parallelism_Rank={pp_rank}]------------------------------"
)
input_shape, rpt = report_memory_usage_one_pp_rank(
input_shape, args, config, pp_rank, config.pipeline_model_parallel_size
)
cli_reports.append(rpt)
else:
set_pipeline_model_parallel_rank(0)
_, rpt = report_memory_usage_one_pp_rank(input_shape, args, config)
cli_reports.append(rpt)
aggregated_reports: list[dict] = cli_reports
# θΏε (θεεη pp ζ₯εε葨, ε
¨ι raw chunk ε葨)
return aggregated_reports, cli_reports
def _get_transformer_layer_spec(use_te, config):
"""Get transformer layer specification based on configuration.
Args:
use_te (bool): Whether to use Transformer Engine
args: Training arguments
config: Model configuration
Returns:
transformer_layer_spec: The transformer layer specification
"""
if use_te:
return get_gpt_layer_with_transformer_engine_spec(
config.num_moe_experts,
config.moe_grouped_gemm,
config.qk_layernorm,
config.multi_latent_attention,
config.fp8,
)
else:
return get_gpt_layer_local_spec(
config.num_moe_experts,
config.moe_grouped_gemm,
config.qk_layernorm,
config.multi_latent_attention,
)
def model_provider(
args, config, pre_process=True, post_process=True, vp_stage: Optional[int] = None
) -> GPTModel:
use_te = True
if args.num_experts:
# Define the decoder block spec
transformer_layer_spec = get_gpt_decoder_block_spec(
config,
use_transformer_engine=use_te,
normalization="LayerNorm",
qk_l2_norm=False,
vp_stage=vp_stage,
)
else:
# Define the decoder layer spec
transformer_layer_spec = _get_transformer_layer_spec(use_te, config)
mtp_block_spec = None
# TODO fp8
model = GPTModel(
config=config,
transformer_layer_spec=transformer_layer_spec,
vocab_size=args.padded_vocab_size,
max_sequence_length=args.max_position_embeddings,
pre_process=pre_process,
post_process=post_process,
fp16_lm_cross_entropy=getattr(config, "fp16_lm_cross_entropy", False),
parallel_output=True,
share_embeddings_and_output_weights=False,
position_embedding_type="rope",
rotary_percent=getattr(args, "rotary_percent", 1.0),
rotary_base=getattr(args, "rotary_base", 10000),
rope_scaling=getattr(config, "use_rope_scaling", False),
mtp_block_spec=mtp_block_spec,
vp_stage=vp_stage,
)
return model
def get_model(
model_provider_func, args, config, model_type=ModelType.encoder_or_decoder
):
"""Build the model."""
# args = get_args()
# args.model_type = model_type
# Build model.
if not getattr(args, "virtual_pipeline_model_parallel_size", None):
args.virtual_pipeline_model_parallel_size = None
if config.pipeline_model_parallel_layout:
args.virtual_pipeline_model_parallel_size = (
config.pipeline_model_parallel_layout.virtual_pipeline_model_parallel_size
)
config.virtual_pipeline_model_parallel_size = (
config.pipeline_model_parallel_layout.virtual_pipeline_model_parallel_size
)
def build_model():
if (
get_pipeline_model_parallel_world_size() > 1
and args.virtual_pipeline_model_parallel_size is not None
):
if model_type == ModelType.encoder_and_decoder:
assert (
config.encoder_pipeline_model_parallel_size == 0
), "Interleaved schedule not supported for model with encoder on separate PP rank"
model = []
for i in range(args.virtual_pipeline_model_parallel_size):
# Set pre_process and post_process only after virtual rank is set.
pre_process = is_pipeline_first_stage(ignore_virtual=False, vp_stage=i)
post_process = is_pipeline_last_stage(ignore_virtual=False, vp_stage=i)
this_model = model_provider_func(
args,
config,
pre_process=pre_process,
post_process=post_process,
vp_stage=i,
)
this_model.model_type = model_type
this_model.vp_stage = i
model.append(this_model)
else:
pre_process = is_pipeline_first_stage()
post_process = is_pipeline_last_stage()
if model_type == ModelType.encoder_and_decoder:
if get_pipeline_model_parallel_world_size() > 1:
rank = get_pipeline_model_parallel_rank()
first_decoder_rank = config.encoder_pipeline_model_parallel_size
world_size = get_pipeline_model_parallel_world_size()
pre_process = rank == 0 or rank == first_decoder_rank
post_process = (rank == (first_decoder_rank - 1)) or (
rank == (world_size - 1)
)
model = model_provider_func(
args,
config,
pre_process=pre_process,
post_process=post_process,
)
else:
model = model_provider_func(
args, config, pre_process=pre_process, post_process=post_process
)
model.model_type = model_type
return model
model = build_model()
if not isinstance(model, list):
model = [model]
return model
NUM_BYTES_IN_MEGABYTE = 1024 * 1024
NUM_BYTES_IN_GIGABYTE = 1024 * 1024 * 1024
def patch_parallel_states():
from megatron.core import parallel_state
parallel_state.is_pipeline_first_stage = is_pipeline_first_stage
parallel_state.is_pipeline_last_stage = is_pipeline_last_stage
parallel_state.get_pipeline_model_parallel_rank = get_pipeline_model_parallel_rank
parallel_state.get_pipeline_model_parallel_world_size = (
get_pipeline_model_parallel_world_size
)
parallel_state.get_virtual_pipeline_model_parallel_world_size = (
get_virtual_pipeline_model_parallel_world_size
)
parallel_state.is_inside_encoder = lambda: False
parallel_state.get_pipeline_model_parallel_decoder_start = lambda: 0
def report_memory_usage(args, config=None):
args.moe_grouped_gemm = True
patch_parallel_states()
if config is None:
if args.yaml_cfg is not None:
config = core_transformer_config_from_yaml(args, "language_model")
else:
config = core_transformer_config_from_args(args)
input_shape = [args.micro_batch_size, args.seq_length]
set_global_config(config)
cli_reports = []
if config.pipeline_model_parallel_size > 1:
for pp_rank in range(config.pipeline_model_parallel_size):
set_pipeline_model_parallel_rank(pp_rank)
print(
f"\n------------------------------[Pipeline_Parallelism_Rank={pp_rank}]------------------------------"
)
input_shape, rpt = report_memory_usage_one_pp_rank(
input_shape, args, config, pp_rank, config.pipeline_model_parallel_size
)
cli_reports.append(rpt)
else:
set_pipeline_model_parallel_rank(0)
_, rpt = report_memory_usage_one_pp_rank(input_shape, args, config)
cli_reports.append(rpt)
# Optionally pretty print summary
print("\n===== Summary (per PP rank) =====")
for r in cli_reports:
print(
f"PP{r['pp_rank']} total {r['total_gb']} GB (weight_grad {r['weight_grad_gb']} GB weight_grad_optim {r['weight_grad_optim_gb']} GB act {r['activation_gb']} GB)"
)
def report_memory_usage_one_pp_rank(
input_shape: list[int], args, config, pp_rank=0, pp_size=1
) -> tuple[list[int], dict]:
print(f"{input_shape=}")
model: list[GPTModel] = get_model(model_provider, args, config)
num_parameter_this_shard_all = 0
num_parameter_this_shard_sparse_all = 0
num_activation_all = 0
output_shape = input_shape
for vpp_rank, one_chunk in enumerate(model):
num_parameter_this_shard = one_chunk.num_parameter()
num_activation = one_chunk.num_activation(output_shape)
output_shape = one_chunk.mock_forward(output_shape)
print(f"{output_shape=}")
num_parameter_this_shard_sparse = 0
for layer in one_chunk.decoder.layers.modules:
if isinstance(layer.mlp, MoELayer):
num_parameter_this_shard_sparse += layer.mlp.num_parameter()
if (
"shared_experts" in layer.mlp.__dir__()
and layer.mlp.shared_experts is not None
):
num_parameter_this_shard_sparse -= (
layer.mlp.shared_experts.num_parameter()
)
num_activation_this_shard_mlp = sum(
[m.mlp.num_activation() for m in one_chunk.decoder.layers.modules]
)
if len(model) > 1:
if vpp_rank >= 1 and vpp_rank < len(model) - 1:
num_microbatch_this_pp_rank = pp_size
elif vpp_rank == 0:
num_microbatch_this_pp_rank = pp_size + max(
(pp_size - pp_rank) * 2 - 1 - pp_size, 0
)
elif vpp_rank == len(model) - 1:
num_microbatch_this_pp_rank = min((pp_size - pp_rank) * 2 + 1, pp_size)
else:
num_microbatch_this_pp_rank = pp_size - pp_rank
num_parameter_this_shard_sparse = 0
for layer in one_chunk.decoder.layers.modules:
if isinstance(layer.mlp, MoELayer):
num_parameter_this_shard_sparse += layer.mlp.num_parameter()
if (
"shared_experts" in layer.mlp.__dir__()
and layer.mlp.shared_experts is not None
):
num_parameter_this_shard_sparse -= (
layer.mlp.shared_experts.num_parameter()
)
one_chunk.__repr__()
print(one_chunk)
print(
f"Number of parameters in every GPU in billions: "
f"{num_parameter_this_shard / 10**9: .2f} where mlp part is {num_parameter_this_shard_sparse / 10**9: .2f}"
)
num_parameter_this_shard_all += num_parameter_this_shard
num_parameter_this_shard_sparse_all += num_parameter_this_shard_sparse
# recompute
if config.recompute_granularity == "full":
recompute_num_layers = config.recompute_num_layers
num_layers = one_chunk.num_layers
common_act = (
one_chunk.num_act_pre
+ one_chunk.num_act_between_layers
* num_layers
* num_microbatch_this_pp_rank
) # recompute with pipeline parallel
info = "With this recomputing setting, the number of activation achieve peak when "
if config.recompute_method == "block":
num_layers_with_loss = num_layers - recompute_num_layers
if num_layers_with_loss == 0:
peak1 = common_act + one_chunk.num_act_post
peak2 = common_act + one_chunk.num_act_per_layer
if peak1 > peak2:
info += "calculating loss"
else:
info += "back-propogating loss"
num_activation = max(peak1, peak2)
else:
info += f"calculating loss with {num_layers_with_loss} non-recompute layers"
num_activation = (
common_act
+ one_chunk.num_act_post
+ one_chunk.num_act_per_layer
* num_layers_with_loss
* num_microbatch_this_pp_rank
)
elif config.recompute_method == "uniform":
peak1 = common_act + one_chunk.num_act_post
peak2 = (
(common_act + one_chunk.num_act_per_layer)
if vpp_rank == 0
else (common_act)
)
if peak1 > peak2:
info += "calculating loss"
else:
info += f"back-propogating loss recomputing every {recompute_num_layers} layers"
num_activation = max(peak1, peak2)
if len(one_chunk.decoder.layers.modules) > 0 and isinstance(
one_chunk.decoder.layers.modules[0].self_attention, MLASelfAttention
): # MLA recompute achieve peak at backward
num_activation += one_chunk.decoder.layers.modules[
0
].self_attention.core_attention.num_activation()
print(info)
else:
num_activation = (
num_activation - one_chunk.num_act_post
) * num_microbatch_this_pp_rank + one_chunk.num_act_post
# CP
num_activation = num_activation / config.context_parallel_size
if pp_size == 1:
print(
f"Number of activation in every GPU in billions: "
f"{num_activation / 10**9: .2f} where mlp part is {num_activation_this_shard_mlp / 10**9: .2f}"
)
else:
print(
f"Number of activation per microbatch in every GPU in billions: "
f"{num_activation / 10**9: .2f} where mlp part is {num_activation_this_shard_mlp / 10**9: .2f}"
f", {num_microbatch_this_pp_rank=} {vpp_rank=}"
)
num_activation_all += num_activation
num_bytes_per_parameter = (
18
if not args.use_distributed_optimizer
else 6 + (12 / args.data_parallel_size / config.context_parallel_size)
)
if config.expert_model_parallel_size * config.expert_tensor_parallel_size > 1:
num_bytes_per_parameter_dense = num_bytes_per_parameter
num_bytes_per_parameter_moe = (
18
if not args.use_distributed_optimizer
else 6
+ (
12
/ (
args.world_size
/ config.pipeline_model_parallel_size
/ config.expert_model_parallel_size
/ config.expert_tensor_parallel_size
)
)
)
print(f"{num_bytes_per_parameter_dense=} {num_bytes_per_parameter_moe=}")
weight_grad_memory = num_parameter_this_shard_all * 6 / NUM_BYTES_IN_GIGABYTE
weight_grad_optim_memory = (
(num_parameter_this_shard_all - num_parameter_this_shard_sparse_all)
* num_bytes_per_parameter_dense
+ num_parameter_this_shard_sparse_all * num_bytes_per_parameter_moe
) / NUM_BYTES_IN_GIGABYTE
else:
print(f"{num_bytes_per_parameter=}")
weight_grad_memory = num_parameter_this_shard_all * 6 / NUM_BYTES_IN_GIGABYTE
weight_grad_optim_memory = (
num_parameter_this_shard_all
* num_bytes_per_parameter
/ NUM_BYTES_IN_GIGABYTE
)
activation_memory = (
num_activation_all * 2 / NUM_BYTES_IN_GIGABYTE
) # only support fp16
total_memory = weight_grad_optim_memory + activation_memory
print(
f"Theoretical memory footprints: weight and optimizer={weight_grad_optim_memory:.2f} GB, "
f"activation={activation_memory:.2f} GB, total={total_memory:.2f} GB\n"
)
# ηζδΈ estimate_from_config ηΈεζ ΌεΌηθεζ₯ε
model_breakdown_concat = "\n\n".join(
[f"--- vpp_chunk {i} ---\n{str(m)}" for i, m in enumerate(model)]
)
report = {
"pp_rank": pp_rank,
"parameters_b": num_parameter_this_shard_all / 1e9,
"activation_b": num_activation_all / 1e9,
"weight_grad_gb": round(weight_grad_memory, 2),
"weight_grad_optim_gb": round(weight_grad_optim_memory, 2),
"activation_gb": round(activation_memory, 2),
"total_gb": round(total_memory, 2),
"model_breakdown": model_breakdown_concat,
"details": None,
}
return output_shape, report
if __name__ == "__main__":
initialize_megatron(allow_no_cuda=True, skip_mpu_initialization=True)
import ipdb
with ipdb.launch_ipdb_on_exception():
args = get_args()
report_memory_usage(args)
|