File size: 5,505 Bytes
7c851a4
 
 
7b9ce4e
3fe5b5b
7b9ce4e
7c851a4
 
 
7b9ce4e
 
7c851a4
 
7b9ce4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7c851a4
 
7b9ce4e
0cd46d4
7b9ce4e
0cd46d4
 
 
 
7b9ce4e
 
 
 
 
 
7c851a4
3fe5b5b
7b9ce4e
 
 
 
 
 
 
3fe5b5b
 
7c851a4
7b9ce4e
7c851a4
 
7b9ce4e
0cd46d4
7b9ce4e
0cd46d4
 
 
 
7b9ce4e
 
7c851a4
7b9ce4e
7c851a4
 
 
0cd46d4
 
 
 
 
 
7b9ce4e
 
 
 
 
 
 
7c851a4
 
 
7b9ce4e
 
7c851a4
 
 
7b9ce4e
 
 
 
7c851a4
 
7b9ce4e
 
7c851a4
0cd46d4
 
7c851a4
7b9ce4e
 
 
7c851a4
 
 
 
 
7b9ce4e
7c851a4
 
 
 
 
7b9ce4e
7c851a4
 
 
7b9ce4e
7c851a4
7b9ce4e
 
7c851a4
 
 
7b9ce4e
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import math
from datetime import datetime
from matplotlib.ticker import FuncFormatter

# Predefined hyperparameter sets, keyed by the name offered in the dropdown.
# Each entry holds the four coefficients that pred_loss consumes:
#     loss = E + (A / (n * (1 + k * log(p)))) ** alpha
# NOTE(review): A is written as <value> * 1e9, presumably because n is passed
# to pred_loss as a raw parameter count rather than in billions — confirm
# against the original fit.
PARAM_SETS = {
    "Stack-V2-Python": {"E": 0.69123678, "A": 0.01130616 * 1e9, "k": 0.393463, "alpha": 0.18937067},
    "Pile": {"E": 1.28254036, "A": 0.2035367 * 1e9, "k": 0.33027934, "alpha": 0.19479807}
}

def pred_loss(E, A, k, alpha, n, p):
    """Evaluate the scaling law ``E + (A / (n * (1 + k * log(p)))) ** alpha``.

    Args:
        E, A, k, alpha: Coefficients of the law.
        n: Parameter count; a scalar or a NumPy array (evaluated element-wise).
        p: Number of parallel streams; its natural logarithm is taken.

    Returns:
        The predicted loss, with the same shape as ``n`` after broadcasting.
    """
    # Parameter count scaled by the logarithmic gain from the parallel streams.
    scaled_params = n * (1 + np.log(p) * k)
    reducible = (A / scaled_params) ** alpha
    return E + reducible

def generate_plot(E, A, k, alpha):
    """Draw the predicted loss curves for P = 1, 2, 4 and 8.

    The current pyplot figure is cleared and redrawn on every call.

    Args:
        E, A, k, alpha: Scaling-law coefficients forwarded to ``pred_loss``;
            they are also printed in the lower-left corner of the axes.

    Returns:
        The ``matplotlib.pyplot`` module itself (not a Figure object); the
        drawing lives on pyplot's current figure.
    """
    # Plotted parameter range, padded by 10% on each side. Shared by the
    # sampling grid and the axis limits so the two cannot drift apart.
    x_min = 535813376 * 0.9
    x_max = 4353203200 * 1.1
    streams = (1, 2, 4, 8)
    colors = ['#2B83BA', '#7BB7D6', '#ED7D5F', '#D7191C']

    # NOTE(review): this reuses pyplot's global "current figure"; if several
    # requests are served concurrently they would share it — confirm.
    plt.clf()
    ax = plt.gca()

    # The grid does not depend on P, so build it once outside the loop.
    x_plot = np.linspace(x_min, x_max, 100)
    for p, color in zip(streams, colors):
        y_plot = pred_loss(E, A, k, alpha, x_plot, p)
        ax.plot(x_plot, y_plot, marker=None, markersize=1, linewidth=3, color=color, label=f"$P={p}$")

    ax.legend(fontsize=12)

    def billions(x, pos):
        # Ticks below one billion are left unlabeled.
        return "" if x < 1e9 else f'{x * 1e-9:.1f}B'

    def two_decimals(y, pos):
        return f"{y:.2f}"

    ax.xaxis.set_major_formatter(FuncFormatter(billions))
    ax.xaxis.set_minor_formatter(FuncFormatter(billions))
    ax.yaxis.set_major_formatter(FuncFormatter(two_decimals))
    ax.yaxis.set_minor_formatter(FuncFormatter(two_decimals))
    ax.set_xlim(x_min, x_max)

    # Keep the automatic lower bound and add 1% headroom at the top.
    y_low, y_high = ax.get_ylim()
    ax.set_ylim(y_low, y_high * 1.01)

    ax.text(0.03, 0.03, f"$E={E}$\n$A={A}$\n$k={k}$\n$\\alpha={alpha}$", transform=ax.transAxes, fontsize=10, verticalalignment='bottom', multialignment='left')

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    ax.set_xlabel('Parameters (Non-Embedding)', fontsize=12)
    ax.set_ylabel('Loss', fontsize=12)
    return plt


# Markdown result text. Placeholders: n (model size in billions), p (parallel
# streams), loss, and n1/n2/n4/n8 (equivalent model sizes in billions at
# P = 1, 2, 4, 8).
OUTPUT_TEMPLATE = """Loss for a {n}B model when P={p} is: **{loss:.5f}**. It is equivalent to:

- A **{n1}B** model with **P=1**;
- A **{n2}B** model with **P=2**;
- A **{n4}B** model with **P=4**;
- A **{n8}B** model with **P=8**;

Note: The equivalent parameters are for reference only. In some reasoning tasks, scaling the parallel streams will obtain more performance gains than the loss benefits!

Enjoy it! 😊"""

def process_inputs(E, A, k, alpha, n, p):
    """Compute the plot and the Markdown summary for one set of inputs.

    Args:
        E, A, k, alpha: Scaling-law coefficients.
        n: Model size in billions of non-embedding parameters.
        p: Number of parallel streams.

    Returns:
        A ``(plot, text)`` pair: the value returned by ``generate_plot`` and
        ``OUTPUT_TEMPLATE`` filled in with the predicted loss and the model
        sizes that reach the same loss at P = 1, 2, 4 and 8.

    Raises:
        gr.Error: If any input is missing, if ``n`` or ``p`` is not positive,
            or if ``1 + k * log(p)`` is not positive (the law is undefined
            there and would otherwise yield NaN/inf).
    """
    # An emptied number box presumably arrives as None; fail with a readable
    # message instead of a TypeError from the arithmetic below.
    if any(value is None for value in (E, A, k, alpha, n, p)):
        raise gr.Error("Please fill in E, A, k, alpha, N and P.")
    if n <= 0 or p <= 0:
        raise gr.Error("N and P must both be positive.")

    # Multiplicative gain that P parallel streams apply to the parameter count.
    gain = k * np.log(p) + 1
    if gain <= 0:
        raise gr.Error("1 + k * log(P) must be positive for the scaling law to be defined.")

    n = n * 1e9
    plot = generate_plot(E, A, k, alpha)
    loss = pred_loss(E, A, k, alpha, n, p)

    # Model size (in billions) that matches the same loss at each reference P.
    equivalent = {
        streams: n * gain / (k * np.log(streams) + 1) / 1e9
        for streams in (1, 2, 4, 8)
    }

    print(f"[{datetime.now()}] {E = }, {A = }, {k = }, {alpha = }, {n = }, {p = }")

    summary = OUTPUT_TEMPLATE.format(
        n=round(n / 1e9, 2),
        p=p,
        n1=round(equivalent[1], 2),
        n2=round(equivalent[2], 2),
        n4=round(equivalent[4], 2),
        n8=round(equivalent[8], 2),
        loss=loss,
    )
    return plot, summary

# Create interface

# Markdown/HTML banner passed to gr.Markdown as the first element of the page:
# a centered title followed by an arXiv badge that links to the paper.
HEAD = """<div align="center">

# Parallel Scaling Law Visualization

[![Paper](https://img.shields.io/badge/arXiv-2505.10475-red)](https://arxiv.org/abs/2505.10475)
</div>
"""

with gr.Blocks() as demo:
    gr.Markdown(HEAD)

    # Single source of truth for the initial UI state. The first pre-fitted
    # set seeds the dropdown, the four coefficient boxes and the initial plot,
    # so these cannot drift apart when PARAM_SETS is edited.
    default_set_name = next(iter(PARAM_SETS))
    default_params = PARAM_SETS[default_set_name]
    default_n = 2.8
    default_p = 4

    with gr.Row():
        with gr.Column():

            gr.Markdown("""$$
\\text{Loss}=E+\\left(
    \\frac{A}{\\text{Parameters}\\times (1+k\\log P)}
\\right)^{\\alpha}
$$""")

            # Input values
            N = gr.Number(value=default_n, label="N: Number of Non-Embedding Model Parameters (in Billion)")
            P = gr.Number(value=default_p, label="P: Number of Parallel Streams")

            gr.Markdown("---")

            # Hyperparameter selection section
            param_set = gr.Dropdown(
                choices=["Custom", *PARAM_SETS],
                value=default_set_name,
                label="Select our pre-fitted parameters for two datasets"
            )

            # Custom parameter inputs, pre-filled from the default set
            param_E = gr.Number(value=default_params["E"], label="E")
            param_A = gr.Number(value=default_params["A"], label="A")
            param_k = gr.Number(value=default_params["k"], label="k")
            param_alpha = gr.Number(value=default_params["alpha"], label="alpha")

        # Compute the default result once at build time so the output column
        # already shows a plot and a summary before the first click.
        plot, output = process_inputs(
            default_params["E"],
            default_params["A"],
            default_params["k"],
            default_params["alpha"],
            default_n,
            default_p,
        )
        with gr.Column():

            submit_btn = gr.Button("Calculate")
            # Output section
            plot_output = gr.Plot(label="Scaling Law Curve", value=plot)
            result_output = gr.Markdown(label="Result", value=output)

    # Auto-fill parameters when selecting predefined sets
    def update_params(param_set):
        """Return new values for the E, A, k and alpha boxes.

        A known set name yields its four coefficients; any other choice
        (such as "Custom") returns ``gr.skip()`` for each output so the
        current values are left as they are.
        """
        if param_set in PARAM_SETS:
            params = PARAM_SETS[param_set]
            return [params["E"], params["A"], params["k"], params["alpha"]]
        return [gr.skip() for _ in range(4)]

    param_set.change(
        update_params,
        inputs=[param_set],
        outputs=[param_E, param_A, param_k, param_alpha]
    )

    # Submit button event
    click_event = submit_btn.click(
        process_inputs,
        inputs=[param_E, param_A, param_k, param_alpha, N, P],
        outputs=[plot_output, result_output]
    )


# Start the Gradio app. This runs unconditionally at module level, so
# importing this file also launches the app.
demo.launch()