Spaces:
Running
Running
Update args logic.
Browse files- pipeline.py +11 -21
pipeline.py
CHANGED
@@ -280,12 +280,12 @@ def parse_args():
|
|
280 |
"--num-stages",
|
281 |
"-s",
|
282 |
type=int,
|
283 |
-
default=
|
284 |
help="Number of pipeline stages (devices)",
|
285 |
)
|
286 |
|
287 |
parser.add_argument(
|
288 |
-
"--num-batches", "-b", type=int, default=
|
289 |
)
|
290 |
|
291 |
# Forward and backward times
|
@@ -369,6 +369,15 @@ def main():
|
|
369 |
backward_times = None
|
370 |
output_file = "pipeline_1f1b.png"
|
371 |
p2p_time = 0.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
372 |
# Read from config file if provided
|
373 |
if args.config:
|
374 |
try:
|
@@ -387,25 +396,6 @@ def main():
|
|
387 |
print(f"Error reading config file: {str(e)}")
|
388 |
print("Falling back to command line arguments or defaults")
|
389 |
|
390 |
-
# Command line arguments override config file
|
391 |
-
if args.num_stages:
|
392 |
-
num_stages = args.num_stages
|
393 |
-
|
394 |
-
if args.num_batches:
|
395 |
-
num_batches = args.num_batches
|
396 |
-
|
397 |
-
if args.forward_times:
|
398 |
-
forward_times = args.forward_times
|
399 |
-
|
400 |
-
if args.backward_times:
|
401 |
-
backward_times = args.backward_times
|
402 |
-
|
403 |
-
if args.output:
|
404 |
-
output_file = args.output
|
405 |
-
|
406 |
-
if args.p2p_time:
|
407 |
-
p2p_time = args.p2p_time
|
408 |
-
|
409 |
# Validate inputs
|
410 |
if forward_times is None:
|
411 |
forward_times = [1.0] * num_stages
|
|
|
280 |
"--num-stages",
|
281 |
"-s",
|
282 |
type=int,
|
283 |
+
default=0,
|
284 |
help="Number of pipeline stages (devices)",
|
285 |
)
|
286 |
|
287 |
parser.add_argument(
|
288 |
+
"--num-batches", "-b", type=int, default=0, help="Number of micro-batches"
|
289 |
)
|
290 |
|
291 |
# Forward and backward times
|
|
|
369 |
backward_times = None
|
370 |
output_file = "pipeline_1f1b.png"
|
371 |
p2p_time = 0.0
|
372 |
+
|
373 |
+
# Command line arguments override config file
|
374 |
+
num_stages = args.num_stages
|
375 |
+
num_batches = args.num_batches
|
376 |
+
forward_times = args.forward_times
|
377 |
+
backward_times = args.backward_times
|
378 |
+
output_file = args.output
|
379 |
+
p2p_time = args.p2p_time
|
380 |
+
|
381 |
# Read from config file if provided
|
382 |
if args.config:
|
383 |
try:
|
|
|
396 |
print(f"Error reading config file: {str(e)}")
|
397 |
print("Falling back to command line arguments or defaults")
|
398 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
399 |
# Validate inputs
|
400 |
if forward_times is None:
|
401 |
forward_times = [1.0] * num_stages
|