⚛️ Fused vs Unfused Kernels

Overview

In this puzzle, we explore the performance benefits of kernel fusion by implementing and comparing two approaches to the LayerNorm and Linear operation:

  1. Unfused approach: Executes LayerNorm and Linear as separate operations
  2. Fused kernel: Combines LayerNorm and Linear operations into a single GPU kernel

This comparison illustrates how kernel fusion can improve performance by:

  • Reducing memory bandwidth usage
  • Minimizing kernel launch overhead
  • Improving cache utilization
  • Eliminating intermediate memory allocations

Key concepts

In this puzzle, you’ll master:

  • Kernel fusion techniques for combining multiple operations
  • Memory bandwidth optimization through fused operations
  • Performance benchmarking of different kernel implementations
  • Numerical stability in fused operations
  • PyTorch custom operation integration

The mathematical operations we’re fusing are:

  1. LayerNorm: \[\Large \text{LayerNorm}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta \]

  2. Linear: \[\Large \text{Linear}(x) = Wx + b \]

When fused, we compute: \[\Large \text{Fused}(x) = W(\gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta) + b \]
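You can check this composition numerically. The sketch below uses NumPy rather than Mojo. It assumes `W` is stored as `[output_dim, hidden_dim]`, which matches how the kernels index `linear_weight[out_idx, h]`. It runs LayerNorm and Linear as two separate steps and compares the result with the single fused expression. The random inputs are arbitrary.

```python
import numpy as np

EPS = 1e-5

def layernorm(x, gamma, beta, eps=EPS):
    # One mean/variance per position, taken over the last (hidden) axis
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)  # population variance (divide by H)
    return gamma * (x - mu) / np.sqrt(var + eps) + beta

def linear(y, W, b):
    # W has shape [output_dim, hidden_dim], so y @ W.T applies W to every position
    return y @ W.T + b

def fused(x, gamma, beta, W, b, eps=EPS):
    # The same math as one expression: W(gamma * x_hat + beta) + b
    mu = x.mean(axis=-1, keepdims=True)
    inv_std = 1.0 / np.sqrt(x.var(axis=-1, keepdims=True) + eps)
    return (gamma * (x - mu) * inv_std + beta) @ W.T + b

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 4, 8)).astype(np.float32)
gamma = rng.standard_normal(8).astype(np.float32)
beta = rng.standard_normal(8).astype(np.float32)
W = rng.standard_normal((16, 8)).astype(np.float32)
b = rng.standard_normal(16).astype(np.float32)

unfused_out = linear(layernorm(x, gamma, beta), W, b)
fused_out = fused(x, gamma, beta, W, b)
print(unfused_out.shape)  # (4, 4, 16)
print(np.allclose(unfused_out, fused_out, atol=1e-5))
```

Both paths compute the same values. Fusion changes where the intermediate results live, not the math.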

Understanding LayerNorm

LayerNorm is a normalization technique that helps stabilize and accelerate the training of deep neural networks. Let’s break down its components and parameters:

What LayerNorm does

  1. Normalization: LayerNorm normalizes the activations across the features (hidden dimensions) for each sample independently. This means:

    • For each sequence position, it computes statistics across the hidden dimension
    • Each sample in the batch is normalized independently
    • This is different from BatchNorm, which normalizes across the batch dimension
  2. Parameters:

    • \(\gamma\) (scale): A learnable parameter vector that allows the network to learn the optimal scale for each feature
    • \(\beta\) (shift): A learnable parameter vector that allows the network to learn the optimal shift for each feature
    • \(\epsilon\): A small constant (1e-5) added to the variance to prevent division by zero
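Here is a small NumPy sketch of the difference in reduction axes, using random `[batch, seq, hidden]` data. It is illustrative only, not puzzle code. LayerNorm produces one mean and variance per position, while a BatchNorm-style reduction produces one per feature.

```python
import numpy as np

# Activations shaped [batch, seq, hidden], matching the puzzle's layout
rng = np.random.default_rng(1)
x = rng.standard_normal((4, 4, 8)) * 3.0 + 5.0

# LayerNorm: statistics per (batch, seq) position, reduced over hidden
ln_mean = x.mean(axis=-1)          # shape (4, 4)
ln_var = x.var(axis=-1)            # shape (4, 4)

# BatchNorm-style: statistics per feature, reduced over batch and seq
bn_mean = x.mean(axis=(0, 1))      # shape (8,)

x_hat = (x - ln_mean[..., None]) / np.sqrt(ln_var[..., None] + 1e-5)
print(ln_mean.shape, bn_mean.shape)  # (4, 4) (8,)
# Every position now has (approximately) zero mean and unit variance
print(np.allclose(x_hat.mean(axis=-1), 0.0, atol=1e-6))
print(np.allclose(x_hat.var(axis=-1), 1.0, atol=1e-3))
```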

What LayerNorm does in practice

LayerNorm performs several crucial functions in deep neural networks:

  1. Feature standardization:

    • Standardizes each position’s feature vector to zero mean and unit variance (before \(\gamma\) and \(\beta\) are applied)
    • Makes the network’s learning process more stable
    • Is often credited with reducing “internal covariate shift”, where the distribution of layer inputs changes during training
  2. Gradient flow:

    • Improves gradient flow through the network
    • Helps mitigate vanishing/exploding gradients
    • Often tolerates higher learning rates, which can make training more efficient
  3. Regularization effect:

    • May act as a mild form of implicit regularization
    • Keeping feature distributions normalized can help reduce overfitting
    • Makes the network largely insensitive to the overall scale of each input vector
  4. Sequence modeling:

    • Standard in transformer architectures
    • Statistics are computed per token, so they do not depend on sequence length or on other positions
    • Works naturally with variable-length sequences
  5. Training dynamics:

    • Often accelerates training convergence
    • Can reduce the need for careful learning rate tuning
    • Can make the network less sensitive to weight initialization

Mathematical components

  1. Mean Calculation (\(\mu\)): \[\Large \mu = \frac{1}{H} \sum_{i=1}^{H} x_i \]

    • Computes the mean across the hidden dimension (H)
    • Each sequence position has its own mean
  2. Variance Calculation (\(\sigma^2\)): \[\Large \sigma^2 = \frac{1}{H} \sum_{i=1}^{H} (x_i - \mu)^2 \]

    • Computes the variance across the hidden dimension
    • Its square root (after adding \(\epsilon\)) is the divisor that rescales the centered values to unit variance
  3. Normalization and Scaling: \[\Large \text{LayerNorm}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta \]

    • First normalizes the input to have zero mean and unit variance
    • Then applies learnable scale (\(\gamma\)) and shift (\(\beta\)) parameters
    • The \(\odot\) symbol represents elementwise multiplication (Hadamard product)
    • For example, if \(\gamma = [1.2, 0.8, 1.5]\) and the normalized input is \(\hat{x} = [0.5, -0.3, 0.7]\), then \(\gamma \odot \hat{x} = [0.6, -0.24, 1.05]\)
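Here is a worked example of the three formulas on a tiny vector (H = 4), written in NumPy for illustration. It also reproduces the Hadamard-product example above.

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0, 6.0])
H = x.size
mu = x.sum() / H                   # (1 + 2 + 3 + 6) / 4 = 3.0
var = ((x - mu) ** 2).sum() / H    # (4 + 1 + 0 + 9) / 4 = 3.5
x_hat = (x - mu) / np.sqrt(var + 1e-5)
print(mu, var)                     # 3.0 3.5
print(np.round(x_hat, 4))          # zero mean, unit variance (up to epsilon)

# Elementwise (Hadamard) product from the example above
gamma = np.array([1.2, 0.8, 1.5])
normalized = np.array([0.5, -0.3, 0.7])
print(gamma * normalized)          # approximately [0.6, -0.24, 1.05]
```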

Why LayerNorm is important

  1. Training Stability:

    • Prevents activations from growing too large or small
    • Helps maintain consistent signal magnitude throughout the network
  2. Feature Learning:

    • The scale (\(\gamma\)) and shift (\(\beta\)) parameters allow the network to learn which features are important
    • Can effectively learn to ignore or emphasize certain features
  3. Independence:

    • Unlike BatchNorm, LayerNorm’s statistics are computed independently for each sample
    • Makes it more suitable for variable-length sequences and small batch sizes

Configuration

  • Batch size: BATCH_SIZE = 4
  • Sequence length: SEQ_LEN = 4
  • Hidden dimension: HIDDEN_DIM = 8
  • Output dimension: OUTPUT_DIM = 16
  • Epsilon: EPS = 1e-5
  • Data type: DType.float32
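The configuration determines every tensor shape in the puzzle. The bookkeeping below derives those shapes in plain Python. The weight layout follows the kernels' `linear_weight[out_idx, h]` indexing. The flattened matmul shapes follow the `[batch*seq, hidden] @ [hidden, output]` comment in the matmul kernel.

```python
# Shapes implied by the configuration above (plain bookkeeping, no GPU code)
BATCH_SIZE, SEQ_LEN, HIDDEN_DIM, OUTPUT_DIM = 4, 4, 8, 16

input_shape = (BATCH_SIZE, SEQ_LEN, HIDDEN_DIM)   # x
ln_param_shape = (HIDDEN_DIM,)                    # ln_weight (gamma), ln_bias (beta)
linear_weight_shape = (OUTPUT_DIM, HIDDEN_DIM)    # W, indexed [out_idx, h]
linear_bias_shape = (OUTPUT_DIM,)                 # b
output_shape = (BATCH_SIZE, SEQ_LEN, OUTPUT_DIM)

# The unfused matmul flattens batch and sequence into a single row axis
matmul_a_shape = (BATCH_SIZE * SEQ_LEN, HIDDEN_DIM)
matmul_b_shape = (HIDDEN_DIM, OUTPUT_DIM)         # W laid out as [hidden, output]

print(input_shape, "->", output_shape)            # (4, 4, 8) -> (4, 4, 16)
print(matmul_a_shape, "@", matmul_b_shape)        # (16, 8) @ (8, 16)
```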

Implementation approaches

1. Unfused implementation

The unfused approach executes operations separately using multiple kernels. Here are some of the kernels we wrote in the previous chapters:

Matrix multiplication kernel

From Puzzle 14, we reuse the tiled matrix multiplication kernel for the linear transformation. This kernel includes bounds checking to handle variable matrix dimensions safely:

# Idiomatic tiled matmul from p14.mojo - adapted for [batch*seq, hidden] @ [hidden, output] -> [batch*seq, output]
fn matmul_idiomatic_tiled[
    a_layout: Layout,
    b_layout: Layout,
    out_layout: Layout,
    rows: Int,
    cols: Int,
    inner_dim: Int,
](
    output: LayoutTensor[mut=True, dtype, out_layout],
    a: LayoutTensor[mut=False, dtype, a_layout],
    b: LayoutTensor[mut=False, dtype, b_layout],
):
    """Idiomatic tiled matmul following p14.mojo exactly."""
    local_row = thread_idx.x
    local_col = thread_idx.y
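    # Output tiles are indexed (block_idx.x, block_idx.y) by the .tile() calls
    # below, so the bounds check must use the same row/column mapping
    tiled_row = block_idx.x * TPB + local_row
    tiled_col = block_idx.y * TPB + local_col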

    # Get the tile of the output matrix that this thread block is responsible for
    out_tile = output.tile[TPB, TPB](block_idx.x, block_idx.y)
    a_shared = tb[dtype]().row_major[TPB, TPB]().shared().alloc().fill(0)
    b_shared = tb[dtype]().row_major[TPB, TPB]().shared().alloc().fill(0)

    var acc: output.element_type = 0

    alias load_a_layout = Layout.row_major(1, TPB)
    alias load_b_layout = Layout.row_major(TPB, 1)

    for idx in range((inner_dim + TPB - 1) // TPB):
        # Get tiles from A and B matrices
        a_tile = a.tile[TPB, TPB](block_idx.x, idx)
        b_tile = b.tile[TPB, TPB](idx, block_idx.y)

        # Asynchronously copy tiles to shared memory
        copy_dram_to_sram_async[thread_layout=load_a_layout](a_shared, a_tile)
        copy_dram_to_sram_async[thread_layout=load_b_layout](b_shared, b_tile)

        # Wait for all async copies to complete
        async_copy_wait_all()
        barrier()

        # Compute partial matrix multiplication for this tile
        @parameter
        for k in range(TPB):
            acc += a_shared[local_row, k] * b_shared[k, local_col]

        barrier()

    # Write final result with bounds checking (needed for variable matrix sizes)
    if tiled_row < rows and tiled_col < cols:
        out_tile[local_row, local_col] = acc
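The following CPU emulation in NumPy shows what the tiling and bounds check accomplish. It is a sketch of the algorithm, not the Mojo kernel. The tile size `tpb=4` is an arbitrary choice for illustration. Zero-filled tiles mimic the `fill(0)` shared buffers, so partial edge tiles contribute nothing to the sum.

```python
import numpy as np

def tiled_matmul(a, b, tpb=4):
    """CPU emulation of a tiled matmul: for each output tile, stage zero-padded
    tpb x tpb tiles of A and B, accumulate, then write only the in-bounds part."""
    rows, inner = a.shape
    inner_b, cols = b.shape
    assert inner == inner_b
    out = np.zeros((rows, cols), dtype=a.dtype)
    n_tiles = (inner + tpb - 1) // tpb          # ceil-div, as in the kernel's loop
    for tile_r in range(0, rows, tpb):
        for tile_c in range(0, cols, tpb):
            acc = np.zeros((tpb, tpb), dtype=a.dtype)
            for t in range(n_tiles):
                k0 = t * tpb
                # "Shared memory" tiles, zero-filled so partial tiles add nothing
                a_sh = np.zeros((tpb, tpb), dtype=a.dtype)
                b_sh = np.zeros((tpb, tpb), dtype=a.dtype)
                a_blk = a[tile_r:tile_r + tpb, k0:k0 + tpb]
                b_blk = b[k0:k0 + tpb, tile_c:tile_c + tpb]
                a_sh[:a_blk.shape[0], :a_blk.shape[1]] = a_blk
                b_sh[:b_blk.shape[0], :b_blk.shape[1]] = b_blk
                acc += a_sh @ b_sh
            # Bounds check: write only rows < rows and columns < cols
            r_end = min(tile_r + tpb, rows)
            c_end = min(tile_c + tpb, cols)
            out[tile_r:r_end, tile_c:c_end] = acc[:r_end - tile_r, :c_end - tile_c]
    return out

rng = np.random.default_rng(2)
a = rng.standard_normal((16, 8))   # [batch*seq, hidden]
b = rng.standard_normal((8, 16))   # [hidden, output]
print(np.allclose(tiled_matmul(a, b), a @ b))
print(np.allclose(tiled_matmul(a[:, :7], b[:7, :10]), a[:, :7] @ b[:7, :10]))
```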


Transpose kernel

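The Linear weight is indexed as [out_idx, h] (shape [output_dim, hidden_dim]), whereas the matmul above takes its second operand as [hidden, output]. A transpose bridges the two layouts. This kernel stages each tile in shared memory so that both the global read and the global write are coalesced: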

fn transpose_kernel[
    layout_in: Layout,
    layout_out: Layout,
    rows: Int,
    cols: Int,
](
    output: LayoutTensor[mut=True, dtype, layout_out],
    input: LayoutTensor[mut=False, dtype, layout_in],
):
    """Transpose matrix using shared memory tiling for coalesced access.
    We will learn more about coalesced access in the next part.
    """
    shared_tile = tb[dtype]().row_major[TPB, TPB]().shared().alloc()

    local_row = thread_idx.y
    local_col = thread_idx.x

    global_row = block_idx.y * TPB + local_row
    global_col = block_idx.x * TPB + local_col

    if global_row < rows and global_col < cols:
        shared_tile[local_row, local_col] = input[global_row, global_col]
    else:
        shared_tile[local_row, local_col] = 0.0

    barrier()

    out_row = block_idx.x * TPB + local_row
    out_col = block_idx.y * TPB + local_col

    # Store data from shared memory to global memory (coalesced write)
    # Note: we transpose the shared memory access pattern
    if out_row < cols and out_col < rows:
        output[out_row, out_col] = shared_tile[local_col, local_row]
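The NumPy sketch below reproduces the kernel's index arithmetic with explicit loops over blocks and threads. It lets you check that the swapped block indices, combined with the transposed shared-memory read, produce a correct transpose. The check also covers shapes that are not multiples of the tile size. The tile size `tpb=4` is an arbitrary choice.

```python
import numpy as np

def tiled_transpose(inp, tpb=4):
    """Emulate transpose_kernel's index math: each (block_y, block_x) block loads
    one tile into a 'shared' buffer, then writes it back with swapped indices."""
    rows, cols = inp.shape
    out = np.zeros((cols, rows), dtype=inp.dtype)
    for block_y in range((rows + tpb - 1) // tpb):
        for block_x in range((cols + tpb - 1) // tpb):
            shared = np.zeros((tpb, tpb), dtype=inp.dtype)
            # Load phase: thread (local_row, local_col) reads input[global_row, global_col]
            for local_row in range(tpb):
                for local_col in range(tpb):
                    g_row = block_y * tpb + local_row
                    g_col = block_x * tpb + local_col
                    if g_row < rows and g_col < cols:
                        shared[local_row, local_col] = inp[g_row, g_col]
            # Store phase: block indices swap and shared memory is read transposed
            for local_row in range(tpb):
                for local_col in range(tpb):
                    o_row = block_x * tpb + local_row
                    o_col = block_y * tpb + local_col
                    if o_row < cols and o_col < rows:
                        out[o_row, o_col] = shared[local_col, local_row]
    return out

w = np.arange(16 * 8, dtype=np.float32).reshape(16, 8)   # e.g. [output_dim, hidden_dim]
print(np.array_equal(tiled_transpose(w), w.T))           # True
print(np.array_equal(tiled_transpose(w[:13, :6]), w[:13, :6].T))
```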


Bias addition kernel

A simple elementwise kernel that adds the bias vector at every (batch, sequence) position of the matmul output, using one thread per output feature:

fn add_bias_kernel[
    input_layout: Layout,
    bias_layout: Layout,
    output_layout: Layout,
    batch_size: Int,
    seq_len: Int,
    output_dim: Int,
](
    output: LayoutTensor[mut=True, dtype, output_layout],
    input: LayoutTensor[mut=False, dtype, input_layout],
    bias: LayoutTensor[mut=False, dtype, bias_layout],
):
    """Simple bias addition."""
    batch_idx = block_idx.x
    seq_idx = block_idx.y
    out_idx = thread_idx.x

    if batch_idx >= batch_size or seq_idx >= seq_len or out_idx >= output_dim:
        return

    output[batch_idx, seq_idx, out_idx] = input[
        batch_idx, seq_idx, out_idx
    ] + rebind[Scalar[dtype]](bias[out_idx])


LayerNorm kernel

Now complete this kernel to implement the LayerNorm operation. You’ll need to:

  1. Compute mean \(\mu\) and variance \(\sigma^2\) for each sequence position
  2. Normalize the input using these statistics
  3. Apply the scale \(\gamma\) and shift \(\beta\) parameters
fn layernorm_kernel[
    input_layout: Layout,
    ln_params_layout: Layout,
    output_layout: Layout,
    batch_size: Int,
    seq_len: Int,
    hidden_dim: Int,
](
    output: LayoutTensor[mut=True, dtype, output_layout],
    input: LayoutTensor[mut=False, dtype, input_layout],
    ln_weight: LayoutTensor[mut=False, dtype, ln_params_layout],
    ln_bias: LayoutTensor[mut=False, dtype, ln_params_layout],
):
    batch_idx = block_idx.x
    seq_idx = block_idx.y
    hidden_idx = thread_idx.x

    if (
        batch_idx >= batch_size
        or seq_idx >= seq_len
        or hidden_idx >= hidden_dim
    ):
        return

    # Compute statistics for this sequence position (redundant but simple)
    var sum_val: Scalar[dtype] = 0
    var sq_sum: Scalar[dtype] = 0

    # FILL ME IN (roughly 11 lines)


Implementation steps:

  1. First, compute the mean and variance over the hidden dimension (a simple loop in each thread is enough; a parallel reduction is an optional optimization)
  2. Then normalize the input using these statistics
  3. Finally, apply the scale and shift parameters

Characteristics of the unfused approach:

  • Multiple kernel launches (LayerNorm → MatMul → Bias)
  • Intermediate tensor allocations between operations
  • More memory bandwidth usage due to separate passes
  • Simpler implementation with clear separation of concerns
  • Easier to debug as each operation is isolated
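You can estimate the cost of those intermediate tensors in NumPy. The function below is a hypothetical CPU model of the unfused pipeline, not the puzzle's host code. Each step materializes a full array, which stands in for a kernel writing its result to global memory for the next kernel to read. Whether the weight transpose runs on every call depends on host code that is not shown here; the sketch assumes it does.

```python
import numpy as np

def layernorm(x, gamma, beta, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return gamma * (x - mu) / np.sqrt(var + eps) + beta

def unfused_forward(x, gamma, beta, W, b):
    """Hypothetical model of the unfused path: every step produces a full
    intermediate array, like separate kernels going through global memory."""
    B, S, H = x.shape
    ln_out = layernorm(x, gamma, beta)           # LayerNorm kernel: [B, S, H]
    w_t = W.T.copy()                             # transpose kernel (assumed per call): [H, O]
    mm_out = ln_out.reshape(B * S, H) @ w_t      # matmul kernel: [B*S, O]
    out = mm_out.reshape(B, S, -1) + b           # bias kernel
    return out, {"ln_out": ln_out, "w_t": w_t, "mm_out": mm_out}

rng = np.random.default_rng(3)
x = rng.standard_normal((4, 4, 8)).astype(np.float32)
gamma, beta = np.ones(8, np.float32), np.zeros(8, np.float32)
W = rng.standard_normal((16, 8)).astype(np.float32)
b = rng.standard_normal(16).astype(np.float32)

out, inter = unfused_forward(x, gamma, beta, W, b)
extra_bytes = sum(t.nbytes for t in inter.values())
print(out.shape, extra_bytes)  # (4, 4, 16) 2048
```

At this size the intermediates total only 2 KB, but they grow with batch size, sequence length, and layer width.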
Tips
  1. Thread organization:

    • Use one thread block per sequence position (grid: [batch_size, seq_len])
    • Each thread handles one hidden dimension element
    • Every thread in the block recomputes the statistics for its position; this is redundant but simple (sharing them through shared memory is an optional optimization)
  2. Memory access:

    • Access input tensor with [batch_idx, seq_idx, hidden_idx]
    • Access output tensor with [batch_idx, seq_idx, hidden_idx]
    • Access LayerNorm parameters with [hidden_idx]
  3. Numerical stability:

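    • Add epsilon (1e-5) to the variance before taking the square root
    • Use rebind[Scalar[dtype]] to turn values read from a LayoutTensor into Scalar[dtype]
    • Compute variance as (sq_sum / hidden_dim) - (mean * mean); this single-pass form is cheap but can lose precision when the mean is large relative to the spread of the values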
  4. Performance:

    • Compute mean and variance in a single pass
    • Reuse computed statistics for all elements in sequence
    • Avoid unnecessary memory barriers
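The single-pass variance in the tips is E[x²] - μ². In float32 it can suffer catastrophic cancellation when the values share a large offset. The contrived NumPy sketch below compares it with a two-pass computation, which finds the mean first and then averages squared deviations. Welford's online algorithm is another common remedy. For typical roughly zero-centered activations, the single-pass form is usually adequate.

```python
import numpy as np

def single_pass_var(vals):
    # Mirrors the kernel: accumulate sum and sum of squares, then E[x^2] - mean^2
    s = np.float32(0.0)
    sq = np.float32(0.0)
    for v in vals:
        s += v
        sq += v * v
    n = np.float32(len(vals))
    mean = s / n
    return sq / n - mean * mean

def two_pass_var(vals):
    # Pass 1: mean. Pass 2: average squared deviation from that mean
    s = np.float32(0.0)
    for v in vals:
        s += v
    n = np.float32(len(vals))
    mean = s / n
    acc = np.float32(0.0)
    for v in vals:
        d = v - mean
        acc += d * d
    return acc / n

offsets = np.array([0.1, -0.1, 0.2, -0.2, 0.1, -0.1, 0.2, -0.2])
vals = (10000.0 + offsets).astype(np.float32)   # large mean, small spread
true_var = vals.astype(np.float64).var()
print(abs(float(single_pass_var(vals)) - true_var))  # large error (cancellation)
print(abs(float(two_pass_var(vals)) - true_var))     # much smaller error
```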

Running the code

To test your unfused implementation, run one of the following (uv or pixi):

uv run poe p20 --unfused
pixi run p20 --unfused

Your output will look like this:

Testing with dimensions: [4, 4, 8] -> [4, 4, 16]
✅ Loaded Mojo operations library
============================================================
   Puzzle 20: UNFUSED Algorithm Test & Benchmark
============================================================

🧪 Correctness Testing for UNFUSED Algorithm
====================================================

Testing Reference PyTorch Implementation
-----------------------------------------------
✅ Reference PyTorch
   Max difference: 0.00e+00
   Result: ✅ CORRECT

Testing CPU Implementation
---------------------------------
✅ Using Mojo fused kernel (CPU)
   Max difference: 1.86e-08
   Result: ✅ CORRECT

Testing GPU Unfused Implementation
-----------------------------------------
✅ Using Mojo unfused kernel (GPU)
   Max difference: 1.86e-08
   Result: ✅ CORRECT

Correctness Summary:
   - Reference:   ✅ CORRECT
   - CPU:         ✅ CORRECT
   - GPU unfused: ✅ CORRECT

   Overall Correctness: ✅ ALL CORRECT

Benchmarking CPU vs GPU UNFUSED
------------------------------------------
   Testing CPU performance...
   CPU: 3173.70ms (50 iterations)
   Testing GPU unfused performance...
   GPU unfused: 3183.57ms (50 iterations)

   GPU unfused vs CPU: 1.00x slower
   CPU wins (GPU overhead > computation benefit)

UNFUSED Algorithm Test Completed!

Solution

fn layernorm_kernel[
    input_layout: Layout,
    ln_params_layout: Layout,
    output_layout: Layout,
    batch_size: Int,
    seq_len: Int,
    hidden_dim: Int,
](
    output: LayoutTensor[mut=True, dtype, output_layout],
    input: LayoutTensor[mut=False, dtype, input_layout],
    ln_weight: LayoutTensor[mut=False, dtype, ln_params_layout],
    ln_bias: LayoutTensor[mut=False, dtype, ln_params_layout],
):
    batch_idx = block_idx.x
    seq_idx = block_idx.y
    hidden_idx = thread_idx.x

    if (
        batch_idx >= batch_size
        or seq_idx >= seq_len
        or hidden_idx >= hidden_dim
    ):
        return

    # Compute statistics for this sequence position (redundant but simple)
    var sum_val: Scalar[dtype] = 0
    var sq_sum: Scalar[dtype] = 0

    @parameter
    for h in range(hidden_dim):
        val = input[batch_idx, seq_idx, h]
        sum_val += rebind[Scalar[dtype]](val)
        sq_sum += rebind[Scalar[dtype]](val * val)

    mean_val = sum_val / hidden_dim
    var_val = (sq_sum / hidden_dim) - (mean_val * mean_val)
    inv_std = 1.0 / sqrt(var_val + 1e-5)

    # Apply LayerNorm to this element
    input_val = input[batch_idx, seq_idx, hidden_idx]
    normalized = (input_val - mean_val) * inv_std * rebind[Scalar[dtype]](
        ln_weight[hidden_idx]
    ) + rebind[Scalar[dtype]](ln_bias[hidden_idx])
    output[batch_idx, seq_idx, hidden_idx] = normalized
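To relate the solution to its launch grid, here is a NumPy emulation that runs the kernel body once per `(batch_idx, seq_idx, hidden_idx)` "thread". It is a CPU sketch for checking the logic, not a description of how the GPU executes the kernel. The random inputs are arbitrary.

```python
import numpy as np

def layernorm_kernel_emulated(x, ln_weight, ln_bias, eps=1e-5):
    """Run the kernel body once per (batch_idx, seq_idx, hidden_idx) 'thread'.
    Each thread redoes the statistics loop, as in the solution."""
    batch_size, seq_len, hidden_dim = x.shape
    out = np.empty_like(x)
    for batch_idx in range(batch_size):            # block_idx.x
        for seq_idx in range(seq_len):             # block_idx.y
            for hidden_idx in range(hidden_dim):   # thread_idx.x
                sum_val = np.float32(0.0)
                sq_sum = np.float32(0.0)
                for h in range(hidden_dim):
                    val = x[batch_idx, seq_idx, h]
                    sum_val += val
                    sq_sum += val * val
                mean_val = sum_val / np.float32(hidden_dim)
                var_val = sq_sum / np.float32(hidden_dim) - mean_val * mean_val
                inv_std = np.float32(1.0) / np.sqrt(var_val + np.float32(eps))
                normalized = (x[batch_idx, seq_idx, hidden_idx] - mean_val) * inv_std
                out[batch_idx, seq_idx, hidden_idx] = (
                    normalized * ln_weight[hidden_idx] + ln_bias[hidden_idx]
                )
    return out

rng = np.random.default_rng(4)
x = rng.standard_normal((4, 4, 8)).astype(np.float32)
w = rng.standard_normal(8).astype(np.float32)
bias = rng.standard_normal(8).astype(np.float32)

ref = (x - x.mean(-1, keepdims=True)) / np.sqrt(x.var(-1, keepdims=True) + 1e-5) * w + bias
print(np.abs(layernorm_kernel_emulated(x, w, bias) - ref).max())  # small (float32 rounding)
```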


The unfused implementation follows a straightforward approach where each thread handles one element of the output tensor. Let’s break down the key components:

  1. Thread and Block Organization:

    batch_idx = block_idx.x
    seq_idx = block_idx.y
    hidden_idx = thread_idx.x
    
    • Each thread block handles one sequence position in the batch
    • Grid dimensions: [batch_size, seq_len]
    • Each thread processes one element in the hidden dimension
    • Early return if indices are out of bounds:
      if (batch_idx >= batch_size or seq_idx >= seq_len or hidden_idx >= hidden_dim):
          return
      
  2. Statistics Computation:

    var sum_val: Scalar[dtype] = 0
    var sq_sum: Scalar[dtype] = 0
    
    @parameter
    for h in range(hidden_dim):
        val = input[batch_idx, seq_idx, h]
        sum_val += rebind[Scalar[dtype]](val)
        sq_sum += rebind[Scalar[dtype]](val * val)
    
    • Compute sum and squared sum in a single pass
    • Use @parameter for compile-time loop unrolling
    • rebind[Scalar[dtype]] turns each element read from a LayoutTensor into a Scalar[dtype]
    • Calculate mean and variance:
      mean_val = sum_val / hidden_dim
      var_val = (sq_sum / hidden_dim) - (mean_val * mean_val)
      inv_std = 1.0 / sqrt(var_val + 1e-5)
      
  3. Normalization and Scaling:

    input_val = input[batch_idx, seq_idx, hidden_idx]
    normalized = (input_val - mean_val) * inv_std * rebind[Scalar[dtype]](
        ln_weight[hidden_idx]
    ) + rebind[Scalar[dtype]](ln_bias[hidden_idx])
    output[batch_idx, seq_idx, hidden_idx] = normalized
    
    • Apply normalization: \[\Large \text{normalized} = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta \]
    • Scale with learnable parameter γ (ln_weight)
    • Add learnable bias β (ln_bias)
    • Store result in output tensor
  4. Performance Characteristics:

    • Each thread computes statistics independently
    • No shared memory usage (simple but less efficient)
    • Memory access pattern:
      • Input: [batch_idx, seq_idx, h]
      • Output: [batch_idx, seq_idx, hidden_idx]
      • Parameters: [hidden_idx]
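    • Numerical considerations:
      • Epsilon (1e-5) is added before the square root to avoid division by zero
      • All arithmetic stays in Scalar[dtype] (float32 here)
      • The single-pass formula E[x²] - μ² is fast but can lose precision through cancellation when the mean is large relative to the spread; a two-pass computation or Welford's algorithm is more robust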
  5. Implementation Details:

    • Type Safety:

      • Use Scalar[dtype] for intermediate calculations
      • rebind[Scalar[dtype]] for proper type casting
      • Ensures consistent floating-point precision
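    • Memory Access:

      • Inside the statistics loop, all threads of a block read the same input element at each step (a broadcast)
      • The read of input[batch_idx, seq_idx, hidden_idx] and the output write are coalesced across threads
      • LayerNorm parameters are read at [hidden_idx], which is also contiguous across threads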
    • Computation Flow:

      • Statistics computation: \[\Large O(H) \text{ operations per thread} \]
      • Normalization: \[\Large O(1) \text{ operations per thread} \]
      • Total complexity: \[\Large O(H) \text{ per output element} \]
    • Limitations:

      • Redundant computation of statistics
      • No shared memory for intermediate results
      • High memory bandwidth usage
      • Multiple kernel launches required
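The redundancy is easy to quantify. This back-of-the-envelope count in plain Python compares element reads in the solution kernel with a hypothetical variant. In that variant, statistics are computed once per position (for example, by one thread or through a shared-memory reduction), and each thread then reads only its own element. Actual DRAM traffic will be lower than the raw counts, because repeated reads are served from cache or broadcast.

```python
# Element reads of the input tensor, per the indexing in the solution kernel
BATCH_SIZE, SEQ_LEN, HIDDEN_DIM = 4, 4, 8
positions = BATCH_SIZE * SEQ_LEN

# Solution: each of the H threads loops over all H inputs, then reads its own element
reads_redundant = positions * HIDDEN_DIM * (HIDDEN_DIM + 1)
# Hypothetical shared-statistics variant: H reads for stats + H reads to normalize
reads_shared = positions * (HIDDEN_DIM + HIDDEN_DIM)

print(reads_redundant, reads_shared)  # 1152 256
```

The redundant version grows as O(H²) per position, while the shared version grows as O(H).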

This implementation is correct but not optimized. In the benchmark above, the GPU unfused version is essentially tied with the CPU (3183.57 ms vs 3173.70 ms over 50 iterations). With only 4 × 4 × 8 inputs, these timings are most likely dominated by fixed overheads rather than kernel work. The fused implementation targets these limitations by:

  • Computing statistics once per sequence position
  • Feeding normalized values straight into the linear transformation
  • Reducing memory traffic
  • Eliminating intermediate tensor allocations

2. Fused kernel implementation

The fused kernel combines LayerNorm and Linear operations into a single GPU kernel:

fn minimal_fused_kernel[
    input_layout: Layout,
    ln_params_layout: Layout,
    weight_layout: Layout,
    bias_layout: Layout,
    output_layout: Layout,
    batch_size: Int,
    seq_len: Int,
    hidden_dim: Int,
    output_dim: Int,
](
    output: LayoutTensor[mut=True, dtype, output_layout],
    input: LayoutTensor[mut=False, dtype, input_layout],
    ln_weight: LayoutTensor[mut=False, dtype, ln_params_layout],
    ln_bias: LayoutTensor[mut=False, dtype, ln_params_layout],
    linear_weight: LayoutTensor[mut=False, dtype, weight_layout],
    linear_bias: LayoutTensor[mut=False, dtype, bias_layout],
):
    """Minimal fused kernel - one thread per sequence position to avoid redundancy.
    """
    # Grid: (batch_size, seq_len) - one thread block per sequence position
    # Block: (1,) - single thread per sequence position to avoid redundant computation
    batch_idx = block_idx.x
    seq_idx = block_idx.y

    if batch_idx >= batch_size or seq_idx >= seq_len:
        return

    # Step 1: Compute LayerNorm statistics once per sequence position

    # FILL IN roughly 10 lines

    # Step 2: Compute all outputs for this sequence position

    # FILL IN roughly 10 lines


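Key optimizations:

  • Single kernel launch instead of several (LayerNorm, matmul, bias addition)
  • Normalized values are consumed in registers instead of being written to an intermediate tensor
  • Reduced global memory traffic
  • No intermediate tensor allocations
  • Note: this minimal kernel uses no shared memory, and with a single thread per block there is no cross-thread coalescing; its gains come from fewer launches and less intermediate traffic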
Tips
  1. Thread organization:

    • One thread block per sequence position (grid: [batch_size, seq_len])
    • Single thread per sequence position to avoid redundancy
    • Compute all outputs for each sequence position in one thread
  2. Memory access:

    • Access input tensor with [batch_idx, seq_idx, h]
    • Access output tensor with [batch_idx, seq_idx, out_idx]
    • Access weights with [out_idx, h] for linear layer
  3. Computation flow:

    • Compute LayerNorm statistics once per sequence
    • Reuse the statistics for all output dimensions (each normalized value is recomputed on the fly rather than stored)
    • Combine normalization and linear transformation
  4. Performance:

    • Avoid redundant computation of statistics
    • Minimize memory traffic by fusing operations
    • Use proper type casting with rebind[Scalar[dtype]]

Running the code

To test your fused implementation, run one of the following (uv or pixi):

uv run poe p20 --fused
pixi run p20 --fused

Your output will look like this:

Testing with dimensions: [4, 4, 8] -> [4, 4, 16]
✅ Loaded Mojo operations library
============================================================
   Puzzle 20: FUSED Algorithm Test & Benchmark
============================================================

🧪 Correctness Testing for FUSED Algorithm
==================================================

Testing Reference PyTorch Implementation
-----------------------------------------------
✅ Reference PyTorch
   Max difference: 0.00e+00
   Result: ✅ CORRECT

Testing CPU Implementation
---------------------------------
✅ Using Mojo fused kernel (CPU)
   Max difference: 1.86e-08
   Result: ✅ CORRECT

Testing GPU Fused Implementation
---------------------------------------
✅ Using Mojo fused kernel (GPU)
   Max difference: 1.86e-08
   Result: ✅ CORRECT

Correctness Summary:
   - Reference:   ✅ CORRECT
   - CPU:         ✅ CORRECT
   - GPU fused: ✅ CORRECT

   Overall Correctness: ✅ ALL CORRECT

⚡ Benchmarking CPU vs GPU FUSED
----------------------------------------
   Testing CPU performance...
   CPU: 3144.75ms (50 iterations)
   Testing GPU fused performance...
   GPU fused: 3116.11ms (50 iterations)

   GPU fused vs CPU: 1.01x faster
   GPU fused wins!

FUSED Algorithm Test Completed!

Solution

fn minimal_fused_kernel[
    input_layout: Layout,
    ln_params_layout: Layout,
    weight_layout: Layout,
    bias_layout: Layout,
    output_layout: Layout,
    batch_size: Int,
    seq_len: Int,
    hidden_dim: Int,
    output_dim: Int,
](
    output: LayoutTensor[mut=True, dtype, output_layout],
    input: LayoutTensor[mut=False, dtype, input_layout],
    ln_weight: LayoutTensor[mut=False, dtype, ln_params_layout],
    ln_bias: LayoutTensor[mut=False, dtype, ln_params_layout],
    linear_weight: LayoutTensor[mut=False, dtype, weight_layout],
    linear_bias: LayoutTensor[mut=False, dtype, bias_layout],
):
    """Minimal fused kernel - one thread per sequence position to avoid redundancy.
    """
    # Grid: (batch_size, seq_len) - one thread block per sequence position
    # Block: (1,) - single thread per sequence position to avoid redundant computation
    batch_idx = block_idx.x
    seq_idx = block_idx.y

    if batch_idx >= batch_size or seq_idx >= seq_len:
        return

    # Step 1: Compute LayerNorm statistics once per sequence position
    var sum_val: Scalar[dtype] = 0
    var sq_sum: Scalar[dtype] = 0

    @parameter
    for h in range(hidden_dim):
        val = input[batch_idx, seq_idx, h]
        sum_val += rebind[Scalar[dtype]](val)
        sq_sum += rebind[Scalar[dtype]](val * val)

    mean_val = sum_val / hidden_dim
    var_val = (sq_sum / hidden_dim) - (mean_val * mean_val)
    inv_std = 1.0 / sqrt(var_val + 1e-5)

    # Step 2: Compute all outputs for this sequence position
    @parameter
    for out_idx in range(output_dim):
        var acc: Scalar[dtype] = 0

        @parameter
        for h in range(hidden_dim):
            input_val = input[batch_idx, seq_idx, h]
            normalized = (input_val - mean_val) * inv_std * rebind[
                Scalar[dtype]
            ](ln_weight[h]) + rebind[Scalar[dtype]](ln_bias[h])
            acc += rebind[Scalar[dtype]](normalized * linear_weight[out_idx, h])

        output[batch_idx, seq_idx, out_idx] = acc + rebind[Scalar[dtype]](
            linear_bias[out_idx]
        )
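Here is a NumPy emulation of the solution, in which one loop iteration stands in for each single-thread block. It is checked against a vectorized LayerNorm followed by Linear. This is a CPU sketch of the kernel's logic on random inputs, not the Mojo code.

```python
import numpy as np

def fused_kernel_emulated(x, ln_w, ln_b, lin_w, lin_b, eps=1e-5):
    """Emulate minimal_fused_kernel: one 'thread' per (batch_idx, seq_idx)
    computes the LayerNorm statistics once, then all output_dim outputs,
    recomputing each normalized value on the fly instead of storing it."""
    batch_size, seq_len, hidden_dim = x.shape
    output_dim = lin_w.shape[0]                   # lin_w is indexed [out_idx, h]
    out = np.empty((batch_size, seq_len, output_dim), dtype=x.dtype)
    for batch_idx in range(batch_size):           # block_idx.x
        for seq_idx in range(seq_len):            # block_idx.y
            row = x[batch_idx, seq_idx]
            # Step 1: single-pass statistics, as in the solution
            mean_val = row.sum() / hidden_dim
            var_val = (row * row).sum() / hidden_dim - mean_val * mean_val
            inv_std = 1.0 / np.sqrt(var_val + eps)
            # Step 2: every output reuses the same statistics
            for out_idx in range(output_dim):
                acc = 0.0
                for h in range(hidden_dim):
                    normalized = (row[h] - mean_val) * inv_std * ln_w[h] + ln_b[h]
                    acc += normalized * lin_w[out_idx, h]
                out[batch_idx, seq_idx, out_idx] = acc + lin_b[out_idx]
    return out

rng = np.random.default_rng(5)
x = rng.standard_normal((4, 4, 8)).astype(np.float32)
ln_w = rng.standard_normal(8).astype(np.float32)
ln_b = rng.standard_normal(8).astype(np.float32)
lin_w = rng.standard_normal((16, 8)).astype(np.float32)
lin_b = rng.standard_normal(16).astype(np.float32)

# Vectorized reference: LayerNorm followed by Linear
mu = x.mean(axis=-1, keepdims=True)
ref = ((x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + 1e-5) * ln_w + ln_b) @ lin_w.T + lin_b
print(np.abs(fused_kernel_emulated(x, ln_w, ln_b, lin_w, lin_b) - ref).max())
```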


The fused implementation combines operations efficiently:

  1. Thread organization:

    • One thread block per sequence position (grid: [batch_size, seq_len])
    • Single thread per sequence position
    • Block indices: batch_idx = block_idx.x, seq_idx = block_idx.y
  2. LayerNorm phase:

    • Compute sum and squared sum for the sequence position
    • Calculate mean: \[\Large \mu = \frac{1}{H} \sum_{i=1}^{H} x_i \]
    • Calculate variance: \[\Large \sigma^2 = \frac{1}{H} \sum_{i=1}^{H} (x_i - \mu)^2 = \frac{1}{H} \sum_{i=1}^{H} x_i^2 - \mu^2 \] (the code uses the second, single-pass form)
    • Compute inverse standard deviation: \[\Large \text{inv\_std} = \frac{1}{\sqrt{\sigma^2 + \epsilon}} \]
  3. Linear phase:

    • For each output dimension:
      • Compute normalized value: \[\Large \text{normalized} = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta \]
      • Multiply with linear weight and accumulate: \[\Large \text{acc} = \sum_{h=1}^{H} \text{normalized}_h \cdot W_{out,h} \]
      • Add linear bias: \[\Large \text{output} = \text{acc} + b_{out} \]
    • Store result in output[batch_idx, seq_idx, out_idx]
  4. Performance optimizations:

    • Single kernel launch for both operations
    • Reuse computed statistics
    • Minimize memory traffic
    • No intermediate tensor allocations
    • Efficient memory access patterns
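Because the LayerNorm scale and shift are applied just before a linear layer, the fused expression can also be rearranged as \(W(\gamma \odot \hat{x} + \beta) + b = (W \operatorname{diag}(\gamma))\hat{x} + (W\beta + b)\). `minimal_fused_kernel` does not use this folding. It is an optional algebraic variant that a setup with fixed parameters could precompute, shown here as a NumPy check.

```python
import numpy as np

rng = np.random.default_rng(6)
H, O = 8, 16
x_hat = rng.standard_normal(H)          # an already-normalized row
gamma, beta = rng.standard_normal(H), rng.standard_normal(H)
W, b = rng.standard_normal((O, H)), rng.standard_normal(O)

# As in the fused kernel: scale and shift, then multiply by W
direct = W @ (gamma * x_hat + beta) + b

# Folded: W' = W diag(gamma) (column h scaled by gamma[h]), b' = W beta + b
W_folded = W * gamma                    # broadcasts over the last axis
b_folded = W @ beta + b
folded = W_folded @ x_hat + b_folded

print(np.allclose(direct, folded))      # True
```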

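In the runs shown above, the fused kernel was slightly faster than the unfused version (3116.11 ms vs 3183.57 ms for 50 iterations, about 2%). This difference is consistent with reduced memory traffic and fewer kernel launches. At this tiny problem size the gap is small, and with a single thread per sequence position, this minimal kernel leaves most of the GPU idle.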

Advantages of kernel fusion

In this puzzle, we’ve explored two approaches to implementing LayerNorm + Linear operations:

  1. Unfused implementation:

    • Separate kernels for LayerNorm and Linear
    • Simpler implementation but less efficient
    • Higher memory bandwidth usage
    • Multiple kernel launches
    • Benchmark results: 3183.57ms (GPU)
  2. Fused implementation:

    • Single kernel combining both operations
    • More complex, and slightly faster in the runs above (about 2%)
    • Reduced memory bandwidth usage
    • Single kernel launch
    • Benchmark results: 3116.11ms (GPU)

Memory bandwidth optimization

  1. Eliminated memory traffic:

    • No intermediate tensor allocations between operations
    • Reduced global memory reads/writes
    • Normalized values feed the linear transformation directly instead of round-tripping through global memory
    • Memory bandwidth reduction: \[\Large \text{reduction} = \frac{\text{unfused\_bandwidth} - \text{fused\_bandwidth}}{\text{unfused\_bandwidth}}\]
  2. Cache efficiency:

    • Better L1/L2 cache utilization
    • Reduced cache misses
    • Improved memory access patterns
    • Higher arithmetic intensity
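Here is an idealized count for this puzzle's configuration in plain Python, which puts numbers to the reduction formula. It assumes each kernel reads every input once and writes every output once. It ignores caching, the redundant reads inside the kernels, and the weight transpose, so the result is an illustration rather than a measurement.

```python
# Idealized global-memory traffic, counted in elements (not bytes)
B, S, H, O = 4, 4, 8, 16
rows = B * S

unfused = {
    "layernorm": rows * H + 2 * H + rows * H,   # read x, gamma, beta; write ln_out
    "matmul": rows * H + H * O + rows * O,      # read ln_out, weight; write mm_out
    "bias": rows * O + O + rows * O,            # read mm_out, b; write output
}
unfused_traffic = sum(unfused.values())
# Fused: read x, gamma, beta, W, b once and write the output once
fused_traffic = rows * H + 2 * H + O * H + O + rows * O

reduction = (unfused_traffic - fused_traffic) / unfused_traffic
print(unfused_traffic, fused_traffic, round(reduction, 3))  # 1312 544 0.585
```

Under these assumptions, fusion removes roughly 58% of the element transfers. Almost all of the savings come from not writing and rereading `ln_out` and `mm_out`.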

Reduced overhead

  1. Kernel launch optimization:

    • Single kernel launch instead of multiple
    • Lower driver overhead
    • Reduced synchronization points
    • Fewer memory allocations
  2. Resource management:

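    • Intermediate values stay in registers instead of round-tripping through global memory
    • Trade-off: this minimal kernel runs a single thread per block, so occupancy and overall GPU utilization are low
    • Spreading the output dimension across the threads of a block would likely raise utilization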

Performance characteristics

  1. Scalability:

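    • The memory traffic saved grows with the size of the intermediate tensors
    • Eases the memory bandwidth bottleneck
    • Makes more efficient use of GPU resources when the fused kernel exposes enough parallelism (this minimal version does not)
    • Can improve throughput for large models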
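  2. Numerical behavior:

    • Both versions compute in float32 and matched the reference with a max difference of 1.86e-08 in the runs above
    • Fusion does not reduce rounding error here, because the unfused intermediates are also float32
    • Fusion can help precision when an unfused pipeline would store intermediates in a lower-precision format
    • The order of operations still matters, as with the single-pass variance formula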

💡 Key insight: Kernel fusion is particularly beneficial for operations that are frequently used together in neural networks, like LayerNorm + Linear in transformer architectures. The benefits typically grow with input size and model complexity, provided the fused kernel keeps enough of the GPU busy.