Spaces:
Running
Running
File size: 5,505 Bytes
7c851a4 7b9ce4e 3fe5b5b 7b9ce4e 7c851a4 7b9ce4e 7c851a4 7b9ce4e 7c851a4 7b9ce4e 0cd46d4 7b9ce4e 0cd46d4 7b9ce4e 7c851a4 3fe5b5b 7b9ce4e 3fe5b5b 7c851a4 7b9ce4e 7c851a4 7b9ce4e 0cd46d4 7b9ce4e 0cd46d4 7b9ce4e 7c851a4 7b9ce4e 7c851a4 0cd46d4 7b9ce4e 7c851a4 7b9ce4e 7c851a4 7b9ce4e 7c851a4 7b9ce4e 7c851a4 0cd46d4 7c851a4 7b9ce4e 7c851a4 7b9ce4e 7c851a4 7b9ce4e 7c851a4 7b9ce4e 7c851a4 7b9ce4e 7c851a4 7b9ce4e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import math
from datetime import datetime
from matplotlib.ticker import FuncFormatter
# Pre-fitted scaling-law coefficients, keyed by the name shown in the dropdown.
# Every entry carries the four coefficients E, A, k and alpha.
# A is deliberately spelled as `<mantissa> * 1e9`; keep the product form so the
# resulting float stays bit-identical.
PARAM_SETS = {
    "Stack-V2-Python": dict(
        E=0.69123678,
        A=0.01130616 * 1e9,
        k=0.393463,
        alpha=0.18937067,
    ),
    "Pile": dict(
        E=1.28254036,
        A=0.2035367 * 1e9,
        k=0.33027934,
        alpha=0.19479807,
    ),
}
def pred_loss(E, A, k, alpha, n, p):
    """Evaluate ``E + (A / (n * (1 + k * ln p))) ** alpha``.

    Args:
        E, A, k, alpha: Coefficients of the fitted curve.
        n: Parameter count; a scalar or a NumPy array (arithmetic is
            element-wise).
        p: Number of parallel streams; fed to ``np.log`` (natural log).

    Returns:
        The predicted loss, with the same shape as the broadcast inputs.
    """
    stream_factor = 1 + np.log(p) * k
    effective_n = n * stream_factor
    return E + (A / effective_n) ** alpha
def generate_plot(E, A, k, alpha):
    """Draw the predicted loss curves for P = 1, 2, 4 and 8.

    The curves are drawn on pyplot's *current* figure, which is cleared first.

    Args:
        E, A, k, alpha: Scaling-law coefficients forwarded to ``pred_loss``.

    Returns:
        The ``matplotlib.pyplot`` module itself; its current figure holds the
        freshly drawn plot.
    """
    # Plotted parameter range, padded by 10% on either side.
    # NOTE(review): the two counts are presumably the smallest and largest
    # model sizes used for fitting -- confirm.
    x_min = 535813376 * 0.9
    x_max = 4353203200 * 1.1

    stream_counts = [1, 2, 4, 8]
    colors = ['#2B83BA', '#7BB7D6', '#ED7D5F', '#D7191C']

    # NOTE(review): this relies on pyplot's global current figure, so
    # overlapping calls would draw on the same figure.
    plt.clf()
    ax = plt.gca()

    # The x grid does not depend on P, so build it once.
    x_plot = np.linspace(x_min, x_max, 100)
    # Pair each stream count with its colour positionally rather than through
    # a floating-point log2 of P.
    for color, p in zip(colors, stream_counts):
        y_plot = pred_loss(E, A, k, alpha, x_plot, p)
        ax.plot(x_plot, y_plot, marker=None, markersize=1, linewidth=3, color=color, label=f"$P={p}$")
    ax.legend(fontsize=12)

    def billions(x, pos):
        """Label ticks of at least 1e9 as e.g. ``1.5B``; leave smaller ones blank."""
        return "" if x < 1e9 else f'{x * 1e-9:.1f}B'

    def two_decimals(x, pos):
        """Label a tick with two decimal places."""
        return f"{x:.2f}"

    # Each slot gets its own formatter instance.
    ax.xaxis.set_major_formatter(FuncFormatter(billions))
    ax.xaxis.set_minor_formatter(FuncFormatter(billions))
    ax.yaxis.set_major_formatter(FuncFormatter(two_decimals))
    ax.yaxis.set_minor_formatter(FuncFormatter(two_decimals))

    ax.set_xlim(x_min, x_max)
    # Keep the autoscaled lower bound and add 1% headroom at the top.
    y_low, y_high = ax.get_ylim()
    ax.set_ylim(y_low, y_high * 1.01)

    # Print the coefficients in the lower-left corner of the axes.
    ax.text(0.03, 0.03, f"$E={E}$\n$A={A}$\n$k={k}$\n$\\alpha={alpha}$", transform=ax.transAxes, fontsize=10, verticalalignment='bottom', multialignment='left')

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    ax.set_xlabel('Parameters (Non-Embedding)', fontsize=12)
    ax.set_ylabel('Loss', fontsize=12)
    return plt
# Markdown template for the textual result. It is filled with ``str.format``
# using the fields n, p, loss, n1, n2, n4 and n8.
# The blank line before "Note:" keeps the note out of the last bullet item
# when the text is rendered as Markdown.
# NOTE(review): the trailing "π" may be a mis-encoded emoji -- confirm.
OUTPUT_TEMPLATE = """Loss for a {n}B model when P={p} is: **{loss:.5f}**. It is equivalent to:
- A **{n1}B** model with **P=1**;
- A **{n2}B** model with **P=2**;
- A **{n4}B** model with **P=4**;
- A **{n8}B** model with **P=8**;

Note: The equivalent parameters are for reference only. In some reasoning tasks, scaling the parallel streams will obtain more performance gains than the loss benefits!
Enjoy it! π"""
def process_inputs(E, A, k, alpha, n, p):
    """Build the plot and the Markdown summary for one set of inputs.

    Args:
        E, A, k, alpha: Scaling-law coefficients.
        n: Non-embedding parameter count, in billions.
        p: Number of parallel streams.

    Returns:
        A ``(plot, text)`` tuple: the value returned by ``generate_plot`` and
        ``OUTPUT_TEMPLATE`` filled in with the predicted loss and the
        equivalent model sizes for P = 1, 2, 4 and 8.

    Raises:
        gr.Error: If an input is missing, or if ``n`` / ``p`` lie outside the
            domain in which the formula yields a finite value.
    """
    # A cleared number field arrives as None; report it instead of failing
    # later with a TypeError.
    if any(value is None for value in (E, A, k, alpha, n, p)):
        raise gr.Error("Please fill in every input field.")
    if n <= 0:
        raise gr.Error("N must be a positive number.")
    if p <= 0:
        raise gr.Error("P must be a positive number.")
    stream_gain = k * np.log(p) + 1
    # Written as `not (... > 0)` so that a NaN gain is rejected as well.
    if not stream_gain > 0:
        raise gr.Error("1 + k*log(P) must be positive for the given k and P.")

    n = n * 1e9  # billions -> absolute parameter count
    plot = generate_plot(E, A, k, alpha)
    loss = pred_loss(E, A, k, alpha, n, p)

    # The loss depends on n and P only through n * (1 + k*log P), so a model
    # with q streams reaches the same loss at n * (1 + k*log P) / (1 + k*log q).
    effective_n = n * stream_gain
    equivalents = {
        q: round(effective_n / (k * np.log(q) + 1) / 1e9, 2)
        for q in (1, 2, 4, 8)
    }

    print(f"[{datetime.now()}] {E = }, {A = }, {k = }, {alpha = }, {n = }, {p = }")
    return plot, OUTPUT_TEMPLATE.format(
        n=round(n / 1e9, 2),
        p=p,
        n1=equivalents[1],
        n2=equivalents[2],
        n4=equivalents[4],
        n8=equivalents[8],
        loss=loss,
    )
# Create interface
# Page header, rendered as Markdown. The blank lines after <div> and before
# </div> are needed so the heading and the link are parsed as Markdown rather
# than treated as raw content of the HTML block. The link carries visible text
# so that it actually shows up on the page.
HEAD = """<div align="center">

# Parallel Scaling Law Visualization

[arXiv:2505.10475](https://arxiv.org/abs/2505.10475)

</div>
"""
# Values shown when the page first loads. The default coefficient set is the
# first entry of PARAM_SETS, which is also what the dropdown preselects.
DEFAULT_SET_NAME = next(iter(PARAM_SETS))
DEFAULT_PARAMS = PARAM_SETS[DEFAULT_SET_NAME]
DEFAULT_N = 2.8
DEFAULT_P = 4

# Pre-compute the outputs for the defaults so the page is not empty before the
# first click on "Calculate".
plot, output = process_inputs(
    DEFAULT_PARAMS["E"],
    DEFAULT_PARAMS["A"],
    DEFAULT_PARAMS["k"],
    DEFAULT_PARAMS["alpha"],
    DEFAULT_N,
    DEFAULT_P,
)

with gr.Blocks() as demo:
    gr.Markdown(HEAD)
    with gr.Row():
        # Left column: formula and all inputs.
        with gr.Column():
            gr.Markdown(
                "$$\n"
                "\\text{Loss}=E+\\left(\n"
                "\\frac{A}{\\text{Parameters}\\times (1+k\\log P)}\n"
                "\\right)^{\\alpha}\n"
                "$$"
            )
            # Input values
            N = gr.Number(value=DEFAULT_N, label="N: Number of Non-Embedding Model Parameters (in Billion)")
            P = gr.Number(value=DEFAULT_P, label="P: Number of Parallel Streams")
            gr.Markdown("---")
            # Hyperparameter selection section
            param_set = gr.Dropdown(
                choices=["Custom", *PARAM_SETS],
                value=DEFAULT_SET_NAME,
                label="Select our pre-fitted parameters for two datasets",
            )
            # Custom parameter inputs
            param_E = gr.Number(value=DEFAULT_PARAMS["E"], label="E")
            param_A = gr.Number(value=DEFAULT_PARAMS["A"], label="A")
            param_k = gr.Number(value=DEFAULT_PARAMS["k"], label="k")
            param_alpha = gr.Number(value=DEFAULT_PARAMS["alpha"], label="alpha")
        # Right column: trigger button and outputs.
        with gr.Column():
            submit_btn = gr.Button("Calculate")
            # Output section
            plot_output = gr.Plot(label="Scaling Law Curve", value=plot)
            result_output = gr.Markdown(label="Result", value=output)

    # Auto-fill parameters when selecting predefined sets
    def update_params(param_set):
        """Return the four coefficients of the chosen set, or leave the fields untouched.

        Here ``param_set`` is the selected dropdown value, not the widget.
        Any value that is not a key of ``PARAM_SETS`` (e.g. "Custom") keeps
        the current field contents.
        """
        params = PARAM_SETS.get(param_set)
        if params is None:
            return [gr.skip() for _ in range(4)]
        return [params["E"], params["A"], params["k"], params["alpha"]]

    param_set.change(
        update_params,
        inputs=[param_set],
        outputs=[param_E, param_A, param_k, param_alpha],
    )

    # Submit button event
    click_event = submit_btn.click(
        process_inputs,
        inputs=[param_E, param_A, param_k, param_alpha, N, P],
        outputs=[plot_output, result_output],
    )

demo.launch()