Simple LLM Training with GPT-2 Architecture

This repository demonstrates how to train a large language model (LLM) from scratch using the GPT-2 architecture. The model is trained on numerical sequences to learn and predict patterns.

📌 Overview

This project implements a full machine learning pipeline:

  • 📊 Synthetic dataset generation (number sequences)
  • 🔤 Custom tokenizer training
  • 🧠 Model training using GPT-2
  • 🤖 Inference capabilities

🚧 Progress So Far

We have trained a 6.4-million-parameter model that:

  • Uses base-16 (hexadecimal) conversion for tokenization.
  • Adds pairs of numbers with up to 4 digits at 100% accuracy.
  • Is publicly available on GitHub: 🔗 Rajesh-Nair/simple-llm

๐Ÿ—๏ธ Dataset Generator

Synthetic number sequences are generated based on parameters defined in data_config.yaml.

Example Configuration:

  • Number range: 0 - 9999
  • Number of sequences: 100,000
  • Output path: ../simple-llm-data/sequences.txt
  • Delimiters: | (columns), \n (rows)

🔧 To Generate the Dataset:

  1. Update data_config.yaml with your desired parameters.
  2. Run the generator:
    python3 data_generator.py
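
A minimal sketch of what the generator could look like is below; the function name, the fixed seed, and the assumption that each row holds two operands and their sum are illustrative, not the repository's actual data_generator.py:

    import random

    def generate_sequences(n_sequences: int = 100_000, low: int = 0,
                           high: int = 9_999, col_delim: str = "|",
                           path: str = "sequences.txt") -> None:
        # Mirrors the data_config.yaml parameters above: number range, sequence
        # count, column delimiter, and output path. Rows are newline-delimited.
        rng = random.Random(0)
        with open(path, "w") as f:
            for _ in range(n_sequences):
                a, b = rng.randint(low, high), rng.randint(low, high)
                f.write(f"{a}{col_delim}{b}{col_delim}{a + b}\n")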
    

🎯 Training

Step 1: Train the Tokenizer

python3 tokenizer.py

Step 2: Train the Model

python3 trainer.py

Training configurations are managed in train_config.yaml, including:

  • 🔧 Model architecture (layers, heads, embedding size)
  • ⚙️ Training hyperparameters (batch size, learning rate)
  • 💾 Checkpointing and saving
  • ☁️ Hugging Face Hub integration

🔢 Position Embeddings

📏 Learnable vs. Sinusoidal Embeddings

  • Learnable embeddings: trained jointly with the model, so they can adapt to the numeric patterns in the data.
  • Sinusoidal embeddings: fixed, mathematically defined functions of position that require no training.
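
For reference, the standard sinusoidal formulation from the original Transformer paper can be computed as follows (a minimal sketch for comparison; this is not code from this repository):

    import torch

    def sinusoidal_embeddings(max_len: int, d_model: int) -> torch.Tensor:
        # PE(pos, 2i) = sin(pos / 10000^(2i/d)), PE(pos, 2i+1) = cos(...).
        # Assumes d_model is even.
        pos = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)
        i = torch.arange(0, d_model, 2, dtype=torch.float32)
        angles = pos / (10_000 ** (i / d_model))
        emb = torch.zeros(max_len, d_model)
        emb[:, 0::2] = torch.sin(angles)
        emb[:, 1::2] = torch.cos(angles)
        return emb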

🧮 Block Position IDs (Abacus Embedding)

Inspired by the Abacus Embedding paper, we use block position IDs.

Example:

  • Input: +1342+879+2221+
  • Block IDs: 012340123012340
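
A minimal sketch of how these IDs can be computed, assuming '+' delimits the operands (the function name is ours, not the repository's):

    def block_position_ids(text: str, delimiter: str = "+") -> list[int]:
        # Restart the digit counter at every delimiter; the delimiter itself gets 0.
        ids, pos = [], 0
        for ch in text:
            pos = 0 if ch == delimiter else pos + 1
            ids.append(pos)
        return ids

    assert block_position_ids("+1342+879+2221+") == \
        [0, 1, 2, 3, 4, 0, 1, 2, 3, 0, 1, 2, 3, 4, 0]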

๐Ÿ” Why Block Position IDs?

  1. ✅ Commutative support: a + b = b + a, and restarting the position count for every operand reinforces this symmetry.
  2. 🧠 Digit alignment: Helps align the units, tens, and hundreds positions for easier digit-wise processing.

🔄 Digit Reversal

As part of preprocessing:

  • 5672 → 2765 (reversed)
  • Output is reversed back during evaluation.

📈 Benefits of Reversal

  1. 🧒 Human-like computation: After reversal the model sees the units digit first, matching the right-to-left order in which humans carry out addition.
  2. 🎯 Causal attention compatibility: Each carry depends only on digits already generated, so causal attention handles carryover naturally.
  3. 📚 Research-backed approach: Digit reversal has been used successfully in several prior papers on arithmetic with transformers. A minimal sketch of the preprocessing is shown below.
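
The helper below illustrates the idea, assuming '+'-delimited operands (the name is ours, not the repository's):

    def reverse_operands(expr: str, delimiter: str = "+") -> str:
        # "+5672+" -> "+2765+": reverse the digits of every operand.
        # Applying the function twice restores the original, which is how
        # model outputs are un-reversed at evaluation time.
        return delimiter.join(part[::-1] for part in expr.split(delimiter))

    assert reverse_operands("+5672+") == "+2765+"
    assert reverse_operands(reverse_operands("+5672+123+")) == "+5672+123+"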

🧩 Tokenization Strategy

Tokenization is critical for arithmetic modeling. Our approach:

  1. ๐Ÿ“ Shortens sequences: Optimizes context window usage.
  2. ๐Ÿงฌ Boosts generalization: Learns across number patterns.
  3. ๐Ÿ”„ Uses base conversion (e.g., decimal โ†’ hexadecimal) for compact, arithmetic-aware tokens.
  4. ๐Ÿง  Preserves arithmetic logic: Even in higher bases, rules still apply.

We're experimenting with different bases to improve efficiency further.
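
A minimal sketch of the base-conversion idea, combined with the digit reversal above (illustrative helpers, not the repository's tokenizer.py):

    def encode_number(n: int) -> str:
        # Decimal -> hexadecimal, then reversed so the units digit comes first.
        return format(n, "x")[::-1]

    def decode_number(tokens: str) -> int:
        return int(tokens[::-1], 16)

    assert encode_number(5672) == "8261"              # 5672 = 0x1628
    assert decode_number(encode_number(5672)) == 5672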


๐Ÿ” Multi-token Prediction

Predicting multiple tokens at once increases efficiency. This is possible because all numbers are reversed: the low-order digits of the answer depend only on the low-order digits of the inputs, so they can be emitted early.

Example: When predicting two tokens at a time, the output digits 99 appear in the first iteration:

Input (reversed):     +12+873+PPPPPPPP      (P = padding tokens)
Output (reversed):    PPPPPP99PPPPPPPP      (P = padding tokens)
Position IDs:         0120123000000000
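
One way to set up training targets for k-token-ahead prediction is sketched below (an illustrative assumption about the setup, not the repository's exact code): each position is labeled with the next k tokens instead of only the next one.

    import torch

    def multi_token_targets(input_ids: torch.Tensor, k: int = 2,
                            pad_id: int = 0) -> torch.Tensor:
        # targets[b, t, j] holds the token (j + 1) steps after position t;
        # slots that would run past the end of the sequence hold pad_id.
        B, T = input_ids.shape
        targets = torch.full((B, T, k), pad_id, dtype=input_ids.dtype)
        for step in range(1, k + 1):
            targets[:, : T - step, step - 1] = input_ids[:, step:]
        return targets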

We currently support 2-token prediction, and it works well: 🔗 mirajnair/simple-llm-gpt2-v2.0

We are also generalizing this method: emit each output token at the earliest opportunity, so that two or more tokens can be predicted in one pass.

📊 Attention Visualization

Visualizing attention patterns reveals how the model processes arithmetic operations. Below is an example showing attention patterns for the addition problem: 101 + 1002 = 1103 (represented in reversed form as +101+2001+3011+).

Layer 1 Attention Patterns

(Figure: Layer 1 attention visualization)

In this visualization:

  • Bright vertical bars at positions 1, 5, and 10 show how the model focuses on unit digits from both inputs and the output
  • The model learns to align corresponding digit positions (units with units, tens with tens, etc.)
  • Attention patterns reveal how information flows during the addition process, including carry operations

This confirms that our block position ID approach helps the model understand the commutative nature of addition and properly align digits for arithmetic operations.

The visualization demonstrates how the model has learned to focus on relevant digits when performing calculations, similar to how humans process arithmetic problems.
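
A minimal sketch of how such a heatmap can be produced with matplotlib, given a T × T attention matrix from one head (the weights themselves would come from a forward pass with attention outputs enabled; the random weights below are placeholders):

    import matplotlib.pyplot as plt
    import numpy as np

    def plot_attention(attn: np.ndarray, tokens: list[str]) -> None:
        # attn[i, j] = weight that query position i places on key position j.
        fig, ax = plt.subplots()
        im = ax.imshow(attn, cmap="viridis")
        ax.set_xticks(range(len(tokens)), tokens)
        ax.set_yticks(range(len(tokens)), tokens)
        fig.colorbar(im, ax=ax)
        plt.show()

    tokens = list("+101+2001+3011+")                 # the reversed example above
    weights = np.random.rand(len(tokens), len(tokens))
    weights /= weights.sum(axis=1, keepdims=True)    # rows sum to 1, like softmax
    plot_attention(weights, tokens)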

🎯 Performance Results

We've rigorously tested our model's arithmetic capabilities with impressive results:

Addition Performance Test

  • Test Set: 10,000 random pairs of 4-digit numbers
  • Accuracy: 100%
  • Consistency: Maintained perfect accuracy across multiple test runs

This perfect accuracy demonstrates that our approach successfully teaches the model to perform addition operations with complete reliability, even on previously unseen number combinations. The combination of our specialized tokenization strategy, position encoding, and multi-token prediction enables the model to generalize arithmetic rules effectively.

These results validate our architectural choices and confirm that transformer-based models can master fundamental arithmetic operations when properly designed.
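
The test harness can be as simple as the sketch below, where predict(a, b) is assumed to wrap tokenization, generation, and un-reversal of the model's output (the helper is hypothetical):

    import random

    def evaluate_addition(predict, n_trials: int = 10_000,
                          max_val: int = 9_999, seed: int = 0) -> float:
        # Sample random pairs of up-to-4-digit numbers and score exact matches.
        rng = random.Random(seed)
        correct = 0
        for _ in range(n_trials):
            a, b = rng.randint(0, max_val), rng.randint(0, max_val)
            correct += int(predict(a, b) == a + b)
        return correct / n_trials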

🚀 Next Steps

  1. Multi-token Generation:

    • We've shown that the model can output more than one token at a time
    • Test whether the model can generate all output tokens in one go (greedy generation)
  2. Scale Up:

    • Increase the length/number of digits in operations
    • Scale up model size for more complex operations
  3. Length Generalization:

    • Evaluate whether the model can handle numbers longer than those seen during training
  4. Add batch prediction:

    • Implement parallel processing of multiple arithmetic operations
    • Optimize throughput by processing multiple sequences simultaneously
    • Reduce overall inference time for bulk operations
  5. KV cache:

    • Implement key-value caching to reduce redundant computations
    • Optimize memory usage during autoregressive generation
    • Speed up sequential token generation by reusing previous computations