File size: 3,207 Bytes
1b94cd6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4191ca6
1b94cd6
4191ca6
 
 
 
 
1b94cd6
 
 
 
 
 
b4db2f5
1b94cd6
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import gradio as gr

def _as_positive_int(name, value):
    """Coerce a numeric UI input to a positive ``int``.

    ``gr.Number`` delivers floats unless configured otherwise (and ``None``
    when the field is cleared), so integral floats such as ``24.0`` are
    accepted and converted; missing, fractional, zero or negative values are
    rejected.

    Raises:
        ValueError: if *value* is missing, non-integral, or not positive.
    """
    if value is None:
        raise ValueError(f"{name} is required")
    try:
        as_int = int(value)
    except (TypeError, ValueError, OverflowError) as err:
        raise ValueError(f"{name} must be a positive integer, got {value!r}") from err
    if as_int != value or as_int <= 0:
        raise ValueError(f"{name} must be a positive integer, got {value!r}")
    return as_int


def estimate_transformer_stats(batch_size, seq_len, num_layers, hidden_dim, vocab_size, show_breakdown):
    """Estimate parameter count and training FLOPs of a decoder-only Transformer.

    Parameters use the usual dense approximation: 12 * D^2 per layer
    (4 * D^2 attention projections + 8 * D^2 for an FFN with inner size 4 * D)
    plus a single D * V embedding matrix; biases, norms and positional
    embeddings are ignored.

    FLOPs count an (m x n) @ (n x p) matmul as 2 * m * n * p. Because that
    factor 2 is already inside every per-layer term, training cost is taken
    as forward + backward ~= 3 x forward (the familiar "6 * N * T" is this
    same 3 times the factor 2 above).

    Args:
        batch_size: Sequences per training step (B).
        seq_len: Tokens per sequence (S).
        num_layers: Number of Transformer blocks (L).
        hidden_dim: Model width d_model (D).
        vocab_size: Vocabulary size (V).
        show_breakdown: If truthy, append per-component training FLOPs and
            each component's share of the total.

    Returns:
        A multi-line, human-readable report string.

    Raises:
        ValueError: if any size argument is missing or not a positive integer.
    """
    B = _as_positive_int("Batch Size", batch_size)
    S = _as_positive_int("Sequence Length", seq_len)
    L = _as_positive_int("Number of Layers", num_layers)
    D = _as_positive_int("Hidden Size", hidden_dim)
    V = _as_positive_int("Vocabulary Size", vocab_size)

    # --- Parameters ---
    num_params = L * 12 * (D ** 2) + D * V

    # --- Forward FLOPs per layer, per sequence (2 * m * n * p per matmul) ---
    attn_proj_flops = 2 * 3 * S * D * D      # Q, K, V projections
    attn_score_flops = 2 * S * D * S         # Q @ K^T
    attn_value_flops = 2 * S * S * D         # softmax(Q @ K^T) @ V
    attn_out_proj_flops = 2 * S * D * D      # attention output projection
    ffn_flops = 2 * 2 * S * D * 4 * D        # up + down projection, inner size 4D
    # The LM head runs once per sequence, not once per layer; spreading it
    # over the L layers lets the "B * L * layer" product count it exactly once.
    logit_flops = 2 * S * D * V / L

    components = [
        ("QKV Projections", attn_proj_flops),
        ("Attention Scores", attn_score_flops),
        ("Attention Values", attn_value_flops),
        ("Attention Output", attn_out_proj_flops),
        ("FFN", ffn_flops),
        ("Logits", logit_flops),
    ]
    total_layer_flops = sum(flops for _, flops in components)

    # Backward pass ~= 2x forward, so one training step ~= 3x forward FLOPs.
    train_multiplier = 3
    total_flops = train_multiplier * B * L * total_layer_flops

    output_lines = [
        "Parameters: P = 12 * L * D^2 + D * V",
        f"           = 12 * {L} * {D}^2 + {D} * {V} = {num_params:.2e}",
        "",
        "Forward FLOPs per layer (per sequence):",
        f"  Attention Projections (QKV): 2 * 3 * S * D^2 = 2 * 3 * {S} * {D}^2 = {attn_proj_flops:.2e}",
        f"  Attention Scores (QKᵀ):      2 * S * D * S = 2 * {S} * {D} * {S} = {attn_score_flops:.2e}",
        f"  Attention Values (AV):       2 * S * S * D = 2 * {S} * {S} * {D} = {attn_value_flops:.2e}",
        f"  Attention Output Proj:       2 * S * D^2   = 2 * {S} * {D}^2 = {attn_out_proj_flops:.2e}",
        f"  Feedforward Network:         2 * 2 * S * D * 4D = 2*2*{S}*{D}*{4*D} = {ffn_flops:.2e}",
        f"  Logits:                      2 * S * D * V / L = 2*{S}*{D}*{V} / {L} = {logit_flops:.2e}",
        "",
        f"Layer Total FLOPs = {total_layer_flops:.2e}",
        "",
        "Total Training FLOPs (one step) = 3 * B * L * Layer_FLOPs   (forward + backward ≈ 3 × forward)",
        f"                                = 3 * {B} * {L} * {total_layer_flops:.2e} = {total_flops:.2e}",
    ]

    if show_breakdown:
        output_lines.append("\nComponent-wise totals across training batch:")
        # Scale by the same training multiplier so the components sum to the total.
        for label, flops in components:
            output_lines.append(
                f"  - {label}: {train_multiplier * B * L * flops:.2e} "
                f"({100 * flops / total_layer_flops:.1f}%)"
            )

    return "\n".join(output_lines)

# Gradio UI: one numeric field per model dimension plus a toggle for the
# per-component breakdown. precision=0 makes gr.Number round its value and
# deliver an int, so the report shows "24" rather than "24.0".
iface = gr.Interface(
    fn=estimate_transformer_stats,
    inputs=[
        gr.Number(label="Batch Size", value=1, precision=0),
        gr.Number(label="Sequence Length", value=2048, precision=0),
        gr.Number(label="Number of Layers", value=24, precision=0),
        gr.Number(label="Hidden Size (d_model)", value=2048, precision=0),
        gr.Number(label="Vocabulary Size", value=50272, precision=0),
        gr.Checkbox(label="Show FLOPs Breakdown", value=True),
    ],
    outputs=gr.Textbox(label="Estimates"),
    title="Transformer Parameter and FLOPs Estimator",
    description="Estimates parameter count and training FLOPs for decoder-only Transformers (like OPT/GPT). Shows formulas and per-component breakdown."
)

if __name__ == "__main__":
    iface.launch()