chenmouxiang commited on
Commit
7b9ce4e
Β·
verified Β·
1 Parent(s): 087cecd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +101 -46
app.py CHANGED
@@ -1,91 +1,146 @@
1
  import gradio as gr
2
  import matplotlib.pyplot as plt
3
  import numpy as np
 
 
4
 
5
  # Predefined hyperparameter sets
6
  PARAM_SETS = {
7
- "Set A": {"param1": 0.1, "param2": 0.01, "param3": 100, "param4": 50},
8
- "Set B": {"param1": 0.2, "param2": 0.02, "param3": 200, "param4": 100}
9
  }
10
 
11
- def generate_plot(param1, param2, param3, param4):
12
- """Generate visualization based on hyperparameters"""
13
- plt.figure(figsize=(10, 6))
14
- x = np.linspace(0, 10, int(param3))
15
- y = np.sin(x * param1) * np.cos(x * param2) * param4
16
- plt.plot(x, y)
17
- plt.title(f'Parameter Visualization (p1={param1}, p2={param2}, p3={param3}, p4={param4})')
18
- plt.grid(True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  return plt
20
 
21
- def process_inputs(param_set, custom_param1, custom_param2, custom_param3, custom_param4,
22
- input1, input2):
 
 
 
 
 
 
 
 
 
 
 
23
  """Process inputs and return results"""
24
- # Determine which parameter set to use
25
- if param_set in PARAM_SETS:
26
- params = PARAM_SETS[param_set]
27
- p1, p2, p3, p4 = params.values()
28
- else:
29
- p1, p2, p3, p4 = custom_param1, custom_param2, custom_param3, custom_param4
30
-
31
- # Generate plot
32
- plot = generate_plot(p1, p2, p3, p4)
33
-
34
- # Calculate result (example calculation)
35
- result = (input1 * p1 + input2 * p2) * (p3 + p4)
36
 
37
- return plot, result
38
 
39
  # Create interface
 
 
 
 
 
 
 
 
 
 
40
  with gr.Blocks() as demo:
41
- gr.Markdown("# Hyperparameter Calculation and Visualization System")
42
 
43
  with gr.Row():
44
  with gr.Column():
 
 
 
 
 
 
 
45
  # Hyperparameter selection section
46
  param_set = gr.Dropdown(
47
  choices=["Custom"] + list(PARAM_SETS.keys()),
48
- value="Custom",
49
- label="Select Hyperparameter Set"
50
  )
51
 
52
  # Custom parameter inputs
53
- custom_param1 = gr.Number(value=0.1, label="Parameter 1 (Learning Rate)")
54
- custom_param2 = gr.Number(value=0.01, label="Parameter 2 (Weight Decay)")
55
- custom_param3 = gr.Number(value=100, label="Parameter 3 (Iterations)")
56
- custom_param4 = gr.Number(value=50, label="Parameter 4 (Batch Size)")
57
-
58
- # Input values
59
- input1 = gr.Number(value=1.0, label="Input Value 1")
60
- input2 = gr.Number(value=1.0, label="Input Value 2")
61
 
62
- submit_btn = gr.Button("Calculate")
63
 
 
 
64
  with gr.Column():
65
  # Output section
66
- plot_output = gr.Plot(label="Parameter Visualization")
67
- result_output = gr.Number(label="Calculation Result")
 
68
 
69
  # Auto-fill parameters when selecting predefined sets
70
  def update_params(param_set):
71
  if param_set in PARAM_SETS:
72
  params = PARAM_SETS[param_set]
73
- return [params["param1"], params["param2"], params["param3"], params["param4"]]
74
  return [gr.skip(), gr.skip(), gr.skip(), gr.skip()]
75
 
76
  param_set.change(
77
  update_params,
78
  inputs=[param_set],
79
- outputs=[custom_param1, custom_param2, custom_param3, custom_param4]
80
  )
81
 
82
  # Submit button event
83
- submit_btn.click(
84
  process_inputs,
85
- inputs=[param_set, custom_param1, custom_param2, custom_param3, custom_param4,
86
- input1, input2],
87
  outputs=[plot_output, result_output]
88
  )
89
 
90
- # Launch application
91
- demo.launch()
 
1
  import gradio as gr
2
  import matplotlib.pyplot as plt
3
  import numpy as np
4
+ import math
5
+ from matplotlib.ticker import FuncFormatter
6
 
7
  # Predefined hyperparameter sets
8
  PARAM_SETS = {
9
+ "Stack-V2-Python": {"E": 0.69123678, "A": 0.01130616 * 1e9, "k": 0.393463, "alpha": 0.18937067},
10
+ "Pile": {"E": 1.28254036, "A": 0.2035367 * 1e9, "k": 0.33027934, "alpha": 0.19479807}
11
  }
12
 
13
+ def pred_loss(E, A, k, alpha, n, p):
14
+ return E + (A / (n * (1 + np.log(p) * k))) ** alpha
15
+
16
+ def generate_plot(E, A, k, alpha):
17
+ plt.clf()
18
+ colors = ['#2B83BA', '#7BB7D6', '#ED7D5F', '#D7191C']
19
+ ax = plt.gca()
20
+ for i, p in enumerate([1, 2, 4, 8]):
21
+ x_plot = np.linspace(535813376 * 0.9, 4353203200 * 1.1, 100)
22
+ y_plot = pred_loss(E, A, k, alpha, x_plot, p)
23
+ ax.plot(x_plot, y_plot, marker=None, markersize=1, linewidth=3, color=colors[int(math.log(p, 2))], label=f"$P={p}$")
24
+
25
+ ax.legend(fontsize=12)
26
+ # ax.set_xscale("log")
27
+ # ax.set_yscale("log")
28
+
29
+ def billions(x, pos):
30
+ if x < 1e9:
31
+ result = ""
32
+ else:
33
+ result = f'{x * 1e-9:.1f}B'
34
+ return result
35
+
36
+ ax.xaxis.set_major_formatter(FuncFormatter(billions))
37
+ ax.xaxis.set_minor_formatter(FuncFormatter(billions))
38
+ ax.yaxis.set_major_formatter(FuncFormatter(lambda x, pos: f"{x:.2f}"))
39
+ ax.yaxis.set_minor_formatter(FuncFormatter(lambda x, pos: f"{x:.2f}"))
40
+ ax.set_xlim(535813376 * 0.9, 4353203200 * 1.1)
41
+ ax.set_ylim(ax.get_ylim()[0] * 1, ax.get_ylim()[1] * 1.01)
42
+
43
+ ax.text(0.03, 0.03, f"$E={E}$\n$A={A}$\n$k={k}$\n$\\alpha={alpha}$", transform=ax.transAxes, fontsize=10, verticalalignment='bottom', multialignment='left')
44
+
45
+ ax.spines['top'].set_visible(False)
46
+ ax.spines['right'].set_visible(False)
47
+
48
+ ax.set_xlabel('Parameters (Non-Embedding)', fontsize=12)
49
+ ax.set_ylabel(f'Loss', fontsize=12)
50
  return plt
51
 
52
+
53
+ OUTPUT_TEMPLATE = """Loss for a {n}B model when P={p} is: **{loss}**. It is equivalant to:
54
+
55
+ - A {n1}B model with P=1;
56
+ - A {n2}B model with P=2;
57
+ - A {n4}B model with P=4;
58
+ - A {n8}B model with P=8;
59
+
60
+ Note: The equivalent parameters are for reference only. In some reasoning tasks, scaling the parallel streams will obtain more performance gains than the loss benefits!
61
+
62
+ Enjoy it! 😊"""
63
+
64
+ def process_inputs(E, A, k, alpha, n, p):
65
  """Process inputs and return results"""
66
+ if n < 1000:
67
+ n = n * 1e9
68
+ plot = generate_plot(E, A, k, alpha)
69
+ loss = pred_loss(E, A, k, alpha, n, p)
70
+
71
+ n1 = n * (k * np.log(p) + 1) / (k * np.log(1) + 1) / 1e9
72
+ n2 = n * (k * np.log(p) + 1) / (k * np.log(2) + 1) / 1e9
73
+ n4 = n * (k * np.log(p) + 1) / (k * np.log(4) + 1) / 1e9
74
+ n8 = n * (k * np.log(p) + 1) / (k * np.log(8) + 1) / 1e9
 
 
 
75
 
76
+ return plot, OUTPUT_TEMPLATE.format(n=round(n / 1e9, 2), p=p, n1=round(n1, 2), n2=round(n2, 2), n4=round(n4, 2), n8=round(n8, 2), loss=loss)
77
 
78
  # Create interface
79
+
80
+ HEAD = """# Parallel Scaling Law Visualization
81
+
82
+ $$
83
+ \\text{Loss}=E+\\left(
84
+ \\frac{A}{\\text{Parameters}\\times (1+k\\log P)}
85
+ \\right)^{\\alpha}
86
+ $$
87
+ """
88
+
89
  with gr.Blocks() as demo:
90
+ gr.Markdown(HEAD)
91
 
92
  with gr.Row():
93
  with gr.Column():
94
+
95
+ # Input values
96
+ N = gr.Number(value=2.8, label="N: Number of Non-Embedding Model Parameters (in Billion)")
97
+ P = gr.Number(value=4, label="P: Number of Parallel Streams")
98
+
99
+ gr.Markdown("---")
100
+
101
  # Hyperparameter selection section
102
  param_set = gr.Dropdown(
103
  choices=["Custom"] + list(PARAM_SETS.keys()),
104
+ value=list(PARAM_SETS.keys())[0],
105
+ label="Select our pre-fitted parameters for two datasets"
106
  )
107
 
108
  # Custom parameter inputs
109
+ param_E = gr.Number(value=PARAM_SETS["Stack-V2-Python"]['E'], label="E")
110
+ param_A = gr.Number(value=PARAM_SETS["Stack-V2-Python"]['A'], label="A")
111
+ param_k = gr.Number(value=PARAM_SETS["Stack-V2-Python"]['k'], label="k")
112
+ param_alpha = gr.Number(value=PARAM_SETS["Stack-V2-Python"]['alpha'], label="alpha")
 
 
 
 
113
 
114
+ submit_btn = gr.Button("Estimate Loss and Equivalant Model Parameters")
115
 
116
+
117
+ plot, output = process_inputs(PARAM_SETS["Stack-V2-Python"]['E'], PARAM_SETS["Stack-V2-Python"]['A'], PARAM_SETS["Stack-V2-Python"]['k'], PARAM_SETS["Stack-V2-Python"]['alpha'], 2.8, 4)
118
  with gr.Column():
119
  # Output section
120
+ plot_output = gr.Plot(label="Scaling Law Curve", value=plot)
121
+ result_output = gr.Markdown(label="Result", value=output)
122
+
123
 
124
  # Auto-fill parameters when selecting predefined sets
125
  def update_params(param_set):
126
  if param_set in PARAM_SETS:
127
  params = PARAM_SETS[param_set]
128
+ return [params["E"], params["A"], params["k"], params["alpha"]]
129
  return [gr.skip(), gr.skip(), gr.skip(), gr.skip()]
130
 
131
  param_set.change(
132
  update_params,
133
  inputs=[param_set],
134
+ outputs=[param_E, param_A, param_k, param_alpha]
135
  )
136
 
137
  # Submit button event
138
+ click_event = submit_btn.click(
139
  process_inputs,
140
+ inputs=[param_E, param_A, param_k, param_alpha,
141
+ N, P],
142
  outputs=[plot_output, result_output]
143
  )
144
 
145
+
146
+ demo.launch()