Victarry commited on
Commit
a5a3887
·
1 Parent(s): 8854c6a

Update args logic.

Browse files
Files changed (1) hide show
  1. pipeline.py +11 -21
pipeline.py CHANGED
@@ -280,12 +280,12 @@ def parse_args():
280
  "--num-stages",
281
  "-s",
282
  type=int,
283
- default=4,
284
  help="Number of pipeline stages (devices)",
285
  )
286
 
287
  parser.add_argument(
288
- "--num-batches", "-b", type=int, default=10, help="Number of micro-batches"
289
  )
290
 
291
  # Forward and backward times
@@ -369,6 +369,15 @@ def main():
369
  backward_times = None
370
  output_file = "pipeline_1f1b.png"
371
  p2p_time = 0.0
 
 
 
 
 
 
 
 
 
372
  # Read from config file if provided
373
  if args.config:
374
  try:
@@ -387,25 +396,6 @@ def main():
387
  print(f"Error reading config file: {str(e)}")
388
  print("Falling back to command line arguments or defaults")
389
 
390
- # Command line arguments override config file
391
- if args.num_stages:
392
- num_stages = args.num_stages
393
-
394
- if args.num_batches:
395
- num_batches = args.num_batches
396
-
397
- if args.forward_times:
398
- forward_times = args.forward_times
399
-
400
- if args.backward_times:
401
- backward_times = args.backward_times
402
-
403
- if args.output:
404
- output_file = args.output
405
-
406
- if args.p2p_time:
407
- p2p_time = args.p2p_time
408
-
409
  # Validate inputs
410
  if forward_times is None:
411
  forward_times = [1.0] * num_stages
 
280
  "--num-stages",
281
  "-s",
282
  type=int,
283
+ default=0,
284
  help="Number of pipeline stages (devices)",
285
  )
286
 
287
  parser.add_argument(
288
+ "--num-batches", "-b", type=int, default=0, help="Number of micro-batches"
289
  )
290
 
291
  # Forward and backward times
 
369
  backward_times = None
370
  output_file = "pipeline_1f1b.png"
371
  p2p_time = 0.0
372
+
373
+ # Command line arguments override config file
374
+ num_stages = args.num_stages
375
+ num_batches = args.num_batches
376
+ forward_times = args.forward_times
377
+ backward_times = args.backward_times
378
+ output_file = args.output
379
+ p2p_time = args.p2p_time
380
+
381
  # Read from config file if provided
382
  if args.config:
383
  try:
 
396
  print(f"Error reading config file: {str(e)}")
397
  print("Falling back to command line arguments or defaults")
398
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
399
  # Validate inputs
400
  if forward_times is None:
401
  forward_times = [1.0] * num_stages