Victarry commited on
Commit
eabf66f
·
1 Parent(s): a885118

Update visualizer.

Browse files
Files changed (1) hide show
  1. src/visualizer.py +68 -66
src/visualizer.py CHANGED
@@ -70,56 +70,47 @@ def convert_schedule_to_visualization_format(schedule: Schedule):
70
  # Cache the color calculation as it's repeatedly called with the same parameters
71
  @lru_cache(maxsize=128)
72
  def get_color(op_type: str, stage_id: int, num_devices: int):
73
- # A more harmonious blue palette with better progression for forward operations
74
  forward_colors = [
75
- "#0039e6", # Rich navy
76
- "#1a53ff", # Deep blue
77
- "#4d79ff", # Strong blue
78
- "#5c88f2", # Periwinkle blue
79
- "#7094db", # Steel blue
80
- "#809fff", # Medium blue
81
- "#99b3e6", # Pale blue
82
- "#b3c6ff", # Light blue
83
  ]
84
 
85
- # Orange palette for backward operations
86
  backward_colors = [
87
- "#ff8000", # Deep orange
88
- "#ff9933", # Strong orange
89
- "#ffad5c", # Medium orange
90
- "#ffc285", # Light orange
91
- "#ffd699", # Light amber
92
- "#ffd6ad", # Pale orange
93
- "#ffe0c2", # Very pale orange
94
- "#fff0e0", # Lightest orange
95
  ]
96
 
97
- # Improved teal/turquoise palette with better progression for backward_D operations
98
  backward_d_colors = [
99
- "#80ffff", # Light cyan
100
- "#00cccc", # Teal
101
- "#00e6e6", # Bright teal
102
- "#33ffff", # Cyan
103
- "#00b3b3", # Medium teal
 
 
 
104
  "#008080", # Dark teal
105
- "#00e6cc", # Turquoise
106
- "#4ddbbd", # Aqua
107
- "#80d4c8", # Pale teal
108
- "#b3e6e0", # Ice
109
  ]
110
 
111
- # Improved green palette with better progression for backward_W operations
112
  backward_w_colors = [
113
- "#00cc66", # Medium green
114
- "#00e673", # Bright green
115
- "#33ff99", # Mint green
116
- "#80ffbf", # Light green
117
- "#009933", # Forest green
118
- "#006622", # Dark green
119
- "#33cc33", # True green
120
- "#66cc66", # Sage green
121
- "#99cc99", # Pale green
122
- "#c6e6c6", # Pastel green
123
  ]
124
 
125
  virtual_stage = stage_id // num_devices
@@ -130,11 +121,11 @@ def get_color(op_type: str, stage_id: int, num_devices: int):
130
  if op_type == "forward":
131
  return forward_colors[color_index]
132
  elif op_type == "backward":
133
- return backward_colors[color_index]
134
  elif op_type == "backward_D":
135
- return backward_d_colors[color_index]
136
  elif op_type == "backward_W":
137
- return backward_w_colors[color_index]
138
  else:
139
  raise ValueError(f"Invalid operation type: {op_type}")
140
 
@@ -163,6 +154,15 @@ def create_pipeline_figure(
163
  end_time = task["start_time"] + task["duration"]
164
  if end_time > max_time:
165
  max_time = end_time
 
 
 
 
 
 
 
 
 
166
 
167
  # Create a figure
168
  fig = go.Figure()
@@ -251,22 +251,23 @@ def create_pipeline_figure(
251
  )
252
  )
253
 
254
- # Add batch number text for this sub-operation
255
- # Determine text color based on background color
256
- if sub_op["type"] in ["backward", "backward_D", "backward_W"]:
257
- text_color = "black"
258
- else:
259
- text_color = "white"
260
-
261
- annotations.append(
262
- dict(
263
- x=start_time + duration / 2,
264
- y=sub_y_center,
265
- text=f"{sub_op['batch']}",
266
- showarrow=False,
267
- font=dict(color=text_color, size=12, family="Arial, bold"),
 
 
268
  )
269
- )
270
  else:
271
  # Regular (non-overlapped) operation
272
  # Determine task color and text color
@@ -305,16 +306,17 @@ def create_pipeline_figure(
305
  )
306
  )
307
 
308
- # Add batch number text
309
- annotations.append(
310
- dict(
311
- x=start_time + duration / 2,
312
- y=y_pos,
313
- text=f"{task['batch']}",
314
- showarrow=False,
315
- font=dict(color=text_color, size=12, family="Arial, bold"),
 
 
316
  )
317
- )
318
 
319
  # Prepare hover data
320
  hover_text = (
 
70
  # Cache the color calculation as it's repeatedly called with the same parameters
71
  @lru_cache(maxsize=128)
72
  def get_color(op_type: str, stage_id: int, num_devices: int):
73
+ # A more harmonious blue palette with low saturation and high brightness
74
  forward_colors = [
75
+ "#0a5aff", # Intense blue
76
+ "#4c88ff", # Blue (deeper)
77
+ "#7aa7ff", # Medium blue
78
+ "#a8c5ff", # Soft blue
79
+ "#d6e4ff", # Very light blue
 
 
 
80
  ]
81
 
82
+ # Orange palette for backward operations with low saturation and high brightness
83
  backward_colors = [
84
+ "#f47b00", # Intense orange
85
+ "#ffa952", # Orange
86
+ "#ffc78e", # Light orange
87
+ "#ffe6cc", # Very light orange
 
 
 
 
88
  ]
89
 
90
+ # Improved teal/turquoise palette with low saturation and high brightness
91
  backward_d_colors = [
92
+ "#ccffff", # Very light cyan
93
+ "#b3ffff", # Pale cyan
94
+ "#99ffff", # Light cyan
95
+ "#80ffff", # Cyan
96
+ "#66e6e6", # Soft teal
97
+ "#4dcccc", # Light teal
98
+ "#33b3b3", # Teal
99
+ "#009999", # Medium teal
100
  "#008080", # Dark teal
 
 
 
 
101
  ]
102
 
103
+ # Improved green palette with low saturation and high brightness
104
  backward_w_colors = [
105
+ "#ccffe6", # Very light mint
106
+ "#b3ffd9", # Pale mint
107
+ "#99ffcc", # Light mint
108
+ "#80ffbf", # Mint green
109
+ "#66e6a6", # Soft green
110
+ "#4dcc8c", # Light green
111
+ "#33b373", # Medium green
112
+ "#009959", # Forest green
113
+ "#008040", # Dark green
 
114
  ]
115
 
116
  virtual_stage = stage_id // num_devices
 
121
  if op_type == "forward":
122
  return forward_colors[color_index]
123
  elif op_type == "backward":
124
+ return backward_colors[color_index % len(backward_colors)]
125
  elif op_type == "backward_D":
126
+ return backward_d_colors[color_index % len(backward_d_colors)]
127
  elif op_type == "backward_W":
128
+ return backward_w_colors[color_index % len(backward_w_colors)]
129
  else:
130
  raise ValueError(f"Invalid operation type: {op_type}")
131
 
 
154
  end_time = task["start_time"] + task["duration"]
155
  if end_time > max_time:
156
  max_time = end_time
157
+
158
+ # Determine maximum batch number to decide whether to show text labels
159
+ max_batch = 0
160
+ for device in schedule_data:
161
+ for task in schedule_data[device]:
162
+ max_batch = max(max_batch, task["batch"])
163
+
164
+ # Flag to determine whether to show text labels
165
+ show_text_labels = max_batch <= 16
166
 
167
  # Create a figure
168
  fig = go.Figure()
 
251
  )
252
  )
253
 
254
+ # Add batch number text for this sub-operation only if show_text_labels is True
255
+ if show_text_labels:
256
+ # Determine text color based on background color
257
+ if sub_op["type"] in ["backward", "backward_D", "backward_W"]:
258
+ text_color = "black"
259
+ else:
260
+ text_color = "white"
261
+
262
+ annotations.append(
263
+ dict(
264
+ x=start_time + duration / 2,
265
+ y=sub_y_center,
266
+ text=f"{sub_op['batch']}",
267
+ showarrow=False,
268
+ font=dict(color=text_color, size=12, family="Arial, bold"),
269
+ )
270
  )
 
271
  else:
272
  # Regular (non-overlapped) operation
273
  # Determine task color and text color
 
306
  )
307
  )
308
 
309
+ # Add batch number text only if show_text_labels is True
310
+ if show_text_labels:
311
+ annotations.append(
312
+ dict(
313
+ x=start_time + duration / 2,
314
+ y=y_pos,
315
+ text=f"{task['batch']}",
316
+ showarrow=False,
317
+ font=dict(color=text_color, size=12, family="Arial, bold"),
318
+ )
319
  )
 
320
 
321
  # Prepare hover data
322
  hover_text = (