""" h/t to Adam Casson for easy-to-use function to calculate FLOPs, source: https://huggingface.co/spaces/adamcasson/transformer-flops-calculator/blob/main/app.py """ import gradio as gr import plotly.graph_objects as go import numpy as np # Fixed BPE parameters bpe_ps = 4.4 # determined by tokenizer n_ctx_base = 8192 n_heads = 20 n_vocab = 128000 n_layers = 26 # Fixed local model parameters local_d_model = 1024 local_g_size = 1 local_n_ctx = 512 # in bytes local_n_heads = 16 local_n_vocab = 256 local_d_model_k = local_d_model / local_n_heads local_d_ff_multiplier = 4 def openai_flops_per_token(n_layers, n_heads, d_model, n_ctx, n_vocab, ff_ratio=4): """Open AI method for forward pass FLOPs counting of decoder-only Transformer""" d_attn = d_model // n_heads d_ff = d_model * ff_ratio embeddings = 4 * d_model attn_qkv = 2 * n_layers * d_model * 3 * (d_attn * n_heads) attn_mask = 2 * n_layers * n_ctx * (d_attn * n_heads) attn_project = 2 * n_layers * (d_attn * n_heads) * d_model ff = 2 * n_layers * 2 * d_model * d_ff logits = 2 * d_model * n_vocab return embeddings + attn_qkv + attn_mask + attn_project + ff + logits def cross_attention_flops_per_token(n_layers, n_ctx_cross_attn_kv_len, d_model): ca_qo_proj_flops = ( # Cross Attention QO FLOPs + backward 2 * 4 * d_model**2 ) ca_context_flops = 4 * n_ctx_cross_attn_kv_len * d_model return n_layers * (ca_qo_proj_flops + ca_context_flops) def calculate_flops(blt_ps, d_model, local_n_layers): # BPE calculations n_ctx = int(n_ctx_base / bpe_ps) bpe_flops_per_token = openai_flops_per_token(n_layers, n_heads, d_model, n_ctx, n_vocab) bpe_per_byte = bpe_flops_per_token / bpe_ps # BLT Global calculations blt_n_ctx = int(n_ctx_base / blt_ps) blt_global_flops_per_token = openai_flops_per_token(n_layers, n_heads, d_model, blt_n_ctx, n_vocab=0) blt_global_flops_per_byte = blt_global_flops_per_token / blt_ps # BLT Local calculations local_models_transformer_flops_per_byte = openai_flops_per_token( local_n_layers, local_n_heads, local_d_model, local_n_ctx, local_n_vocab ) encoder_model_ca_flops_per_byte = cross_attention_flops_per_token( local_n_layers/2, local_n_ctx, local_d_model ) decoder_model_ca_flops_per_byte = cross_attention_flops_per_token( local_n_layers/2, local_n_ctx // blt_ps, local_d_model ) local_models_cross_attention_flops_per_byte = encoder_model_ca_flops_per_byte + decoder_model_ca_flops_per_byte local_models_flops = local_models_transformer_flops_per_byte + local_models_cross_attention_flops_per_byte # Calculate advantage blt_total = local_models_flops + blt_global_flops_per_byte advantage = 100 * ((blt_total - bpe_per_byte) / bpe_per_byte) return { 'bpe_per_byte': bpe_per_byte, 'blt_global': blt_global_flops_per_byte, 'blt_local': local_models_flops, 'blt_total': blt_total, 'advantage': advantage } def create_visualization(blt_ps, d_model, local_n_layers): results = calculate_flops(blt_ps, d_model, local_n_layers) # Create the figure with subplots for better control fig = go.Figure() # Add BPE bar (only for BPE category) fig.add_trace(go.Bar( name='BPE', x=['BPE'], y=[results['bpe_per_byte']], text=[f"{results['bpe_per_byte']:.2e}"], textposition='outside', marker_color='#FF6B6B', width=0.4, showlegend=True )) # Add BLT Global bar (base of stack) fig.add_trace(go.Bar( name='BLT Global', x=['BLT'], y=[results['blt_global']], text=[f"{results['blt_global']:.2e}"], textposition='inside', marker_color='#4ECDC4', width=0.4, showlegend=True )) # Add BLT Local bar (top of stack) fig.add_trace(go.Bar( name='BLT Local', x=['BLT'], y=[results['blt_local']], text=[f"{results['blt_local']:.2e}"], textposition='inside', marker_color='#45B7D1', width=0.4, showlegend=True )) # Update layout with proper stacking and scientific notation fig.update_layout( title={ 'text': f"FLOPs per Byte Comparison
BLT FLOPs comparison: {results['advantage']:.1f}%", 'x': 0.5, 'xanchor': 'center', 'font': {'size': 20} }, xaxis=dict( title="Architecture", tickfont=dict(size=14) ), yaxis=dict( title="FLOPs per Byte", tickformat=".1e", # Scientific notation with 1 decimal tickfont=dict(size=12), gridcolor='lightgray' ), barmode='stack', showlegend=True, height=600, template="plotly_white", font=dict(size=14), bargap=0.3, plot_bgcolor='white' ) fig.add_annotation( x='BLT', y=results['blt_total'] * 1.1, # Position above stacked bar text=f"Total: {results['blt_total']:.2e}", showarrow=False, font=dict(size=12, color="black", weight="bold"), bgcolor="white", bordercolor="black", borderwidth=1 ) # Update traces to ensure proper stacking fig.update_traces(textfont_size=10) return fig # Create Gradio interface with gr.Blocks(title="BLT vs BPE FLOPs Comparison") as demo: gr.Markdown(""" # BLT vs BPE FLOPs Comparison This interactive visualization compares the computational efficiency (FLOPs per byte) between: - **BPE (Byte Pair Encoding)**: Traditional transformer architecture - **BLT (Byte Latent Transformer)**: Novel architecture with Global and Local components with a dynamic patch size to segment bytes. A few things you'll notice: 1. Patch size reduces global model FLOPs but not local model 2. Increasing patch size and global model dimension doesn't change total FLOPs 3. In smaller BLTs, local models constitute a larger portion of the total FLOPs """) with gr.Row(): with gr.Column(scale=1): gr.Markdown("### Adjustable Parameters") blt_ps_slider = gr.Slider( minimum=1.0, maximum=10.0, value=4.4, step=0.1, label="BLT Patch Size (blt_ps)", info="Patch size for BLT architecture" ) d_model_slider = gr.Slider( minimum=512, maximum=8192, value=2560, step=128, label="Model Dimension (d_model)", info="Hidden dimension size of the model" ) local_n_layers_slider = gr.Slider( minimum=2, maximum=24, value=10, step=2, label="Local Model Layers (local_n_layers)", info="Number of layers in the local model" ) gr.Markdown("### Fixed Parameters") gr.Markdown(""" - **BPE's bytes per token**: 4.4 - **BPE/BLT Number of Layers**: 26 - **BPE/BLT Number of Heads**: 20 - **BPE's Vocabulary Size**: 128,000 - **BPE/BLT Context Length**: 8,192 bytes - **Local Model Dimension**: 1,024 - **Local Model Heads**: 16 """) gr.Markdown("### Current Values") info_text = gr.Markdown("") with gr.Column(scale=2): plot = gr.Plot(label="FLOPs Comparison") # Set up interactivity def update_plot(blt_ps, d_model, local_n_layers): fig = create_visualization(blt_ps, d_model, local_n_layers) # Calculate values for info display results = calculate_flops(blt_ps, d_model, local_n_layers) info_str = f""" **BPE FLOPs/byte**: {results['bpe_per_byte']:.2e} **BLT Global FLOPs/byte**: {results['blt_global']:.2e} **BLT Local FLOPs/byte**: {results['blt_local']:.2e} **BLT Total FLOPs/byte**: {results['blt_total']:.2e} """ return fig, info_str # Update plot when any slider changes blt_ps_slider.change( update_plot, inputs=[blt_ps_slider, d_model_slider, local_n_layers_slider], outputs=[plot, info_text] ) d_model_slider.change( update_plot, inputs=[blt_ps_slider, d_model_slider, local_n_layers_slider], outputs=[plot, info_text] ) local_n_layers_slider.change( update_plot, inputs=[blt_ps_slider, d_model_slider, local_n_layers_slider], outputs=[plot, info_text] ) # Initial plot demo.load( update_plot, inputs=[blt_ps_slider, d_model_slider, local_n_layers_slider], outputs=[plot, info_text] ) # Launch the app if __name__ == "__main__": demo.launch()