Victarry commited on
Commit
fc51155
·
1 Parent(s): 16ed969
Files changed (3) hide show
  1. main.py +1 -1
  2. src/execution_model.py +2 -2
  3. src/visualizer.py +0 -33
main.py CHANGED
@@ -29,7 +29,7 @@ def run_1f1b(cfg: DictConfig) -> None:
29
  num_batches=cfg.num_batches,
30
  p2p_latency=cfg.p2p_latency,
31
  op_times=op_times,
32
- placement_strategy="1f1b"
33
  )
34
  schedule = generate_1f1b_schedule(schedule_config)
35
  executor = ScheduleExecutor(schedule)
 
29
  num_batches=cfg.num_batches,
30
  p2p_latency=cfg.p2p_latency,
31
  op_times=op_times,
32
+ placement_strategy="standard"
33
  )
34
  schedule = generate_1f1b_schedule(schedule_config)
35
  executor = ScheduleExecutor(schedule)
src/execution_model.py CHANGED
@@ -35,7 +35,7 @@ class ScheduleConfig:
35
  num_stages: int,
36
  num_batches: int,
37
  p2p_latency: float = 0.0,
38
- placement_strategy: str = "normal",
39
  op_times: Optional[Dict[str, Union[float, Dict[int, float]]]] = None,
40
  ):
41
  self.num_devices = num_devices
@@ -79,7 +79,7 @@ class ScheduleConfig:
79
  )
80
 
81
  def init_device_to_stages(self):
82
- if self.placement_strategy == "normal":
83
  # Evenly distributed
84
  stages_per_device = self.num_stages // self.num_devices
85
  self.device_to_stages = defaultdict(list)
 
35
  num_stages: int,
36
  num_batches: int,
37
  p2p_latency: float = 0.0,
38
+ placement_strategy: str = "standard",
39
  op_times: Optional[Dict[str, Union[float, Dict[int, float]]]] = None,
40
  ):
41
  self.num_devices = num_devices
 
79
  )
80
 
81
  def init_device_to_stages(self):
82
+ if self.placement_strategy == "standard":
83
  # Evenly distributed
84
  stages_per_device = self.num_stages // self.num_devices
85
  self.device_to_stages = defaultdict(list)
src/visualizer.py CHANGED
@@ -298,18 +298,6 @@ def create_dash_app(schedule: Schedule, schedule_type="1f1b"):
298
  ]),
299
  ], className="config-section"),
300
 
301
- html.Button("Download Image", id="btn-download",
302
- style={
303
- 'marginTop': '20px',
304
- 'padding': '10px',
305
- 'backgroundColor': '#007BFF',
306
- 'color': 'white',
307
- 'border': 'none',
308
- 'borderRadius': '5px',
309
- 'cursor': 'pointer'
310
- }),
311
-
312
- dcc.Download(id="download-image"),
313
  ], style={'margin': '20px'}),
314
 
315
  html.Div(id="graph-container", children=[]),
@@ -329,27 +317,6 @@ def create_dash_app(schedule: Schedule, schedule_type="1f1b"):
329
  # Create the figure when the app loads
330
  return create_pipeline_figure(schedule_data, show_progress=True)
331
 
332
- @app.callback(
333
- Output("download-image", "data"),
334
- Input("btn-download", "n_clicks"),
335
- prevent_initial_call=True,
336
- )
337
- def download_image(n_clicks):
338
- # Generate the figure for download
339
- fig = create_pipeline_figure(schedule_data, show_progress=True)
340
-
341
- # Convert to base64 image
342
- img_bytes = fig.to_image(format="png", width=1600, height=1000, scale=2)
343
- img_base64 = base64.b64encode(img_bytes).decode('ascii')
344
-
345
- # Return the download data
346
- return dict(
347
- content=img_base64,
348
- filename=f"pipeline_visualization_{schedule_type}.png",
349
- type="image/png",
350
- base64=True
351
- )
352
-
353
  return app
354
 
355
 
 
298
  ]),
299
  ], className="config-section"),
300
 
 
 
 
 
 
 
 
 
 
 
 
 
301
  ], style={'margin': '20px'}),
302
 
303
  html.Div(id="graph-container", children=[]),
 
317
  # Create the figure when the app loads
318
  return create_pipeline_figure(schedule_data, show_progress=True)
319
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
320
  return app
321
 
322