⚛️ Fused vs Unfused Kernels
Overview
In this puzzle, we explore the performance benefits of kernel fusion by implementing and comparing two approaches to a LayerNorm followed by a Linear transformation:
- Unfused approach: Executes LayerNorm and Linear as separate operations
- Fused kernel: Combines LayerNorm and Linear operations into a single GPU kernel
This comparison demonstrates how kernel fusion can significantly improve performance by:
- Reducing memory bandwidth usage
- Minimizing kernel launch overhead
- Improving cache utilization
- Eliminating intermediate memory allocations
Key concepts
In this puzzle, you’ll master:
- Kernel fusion techniques for combining multiple operations
- Memory bandwidth optimization through fused operations
- Performance benchmarking of different kernel implementations
- Numerical stability in fused operations
- PyTorch custom operation integration
The mathematical operations we’re fusing are:
- LayerNorm: \[\Large \text{LayerNorm}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta \]
- Linear: \[\Large \text{Linear}(x) = Wx + b \]
When fused, we compute: \[\Large \text{Fused}(x) = W(\gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta) + b \]
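This algebra can be checked numerically in PyTorch before touching any kernels. The following is a small reference sketch (not the puzzle's test harness); the tensor sizes and variable names are illustrative:

import torch

# Illustrative sizes; the puzzle uses its own configuration below.
hidden_dim, output_dim = 8, 16
x = torch.randn(4, 4, hidden_dim)

ln = torch.nn.LayerNorm(hidden_dim, eps=1e-5)
linear = torch.nn.Linear(hidden_dim, output_dim)

# Unfused: two separate operations with an intermediate tensor.
unfused = linear(ln(x))

# Fused math written out explicitly: W(LayerNorm(x)) + b.
mu = x.mean(dim=-1, keepdim=True)
var = x.var(dim=-1, unbiased=False, keepdim=True)
normalized = ln.weight * (x - mu) / torch.sqrt(var + ln.eps) + ln.bias
fused = normalized @ linear.weight.T + linear.bias

print(torch.allclose(unfused, fused, atol=1e-6))  # True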
Understanding LayerNorm
LayerNorm is a normalization technique that helps stabilize and accelerate the training of deep neural networks. Let’s break down its components and parameters:
What LayerNorm does
- Normalization: LayerNorm normalizes the activations across the features (hidden dimensions) for each sample independently. This means:
  - For each sequence position, it computes statistics across the hidden dimension
  - Each sample in the batch is normalized independently
  - This is different from BatchNorm, which normalizes across the batch dimension
- Parameters:
  - \(\gamma\) (scale): A learnable parameter vector that allows the network to learn the optimal scale for each feature
  - \(\beta\) (shift): A learnable parameter vector that allows the network to learn the optimal shift for each feature
  - \(\epsilon\): A small constant (1e-5) added to the variance to prevent division by zero
What LayerNorm does in practice
LayerNorm performs several crucial functions in deep neural networks:
- Feature standardization:
  - Transforms each feature to have zero mean and unit variance
  - Makes the network’s learning process more stable
  - Helps prevent the “internal covariate shift” problem, where the distribution of layer inputs changes during training
- Gradient flow:
  - Improves gradient flow through the network
  - Prevents vanishing/exploding gradients
  - Makes training more efficient by allowing higher learning rates
- Regularization effect:
  - Acts as a form of implicit regularization
  - Helps prevent overfitting by normalizing the feature distributions
  - Makes the network more robust to input variations
- Sequence modeling:
  - Particularly effective in transformer architectures
  - Helps maintain consistent signal magnitude across different sequence lengths
  - Enables better handling of variable-length sequences
- Training dynamics:
  - Accelerates training convergence
  - Reduces the need for careful learning rate tuning
  - Makes the network less sensitive to weight initialization
Mathematical components
- Mean Calculation (\(\mu\)): \[\Large \mu = \frac{1}{H} \sum_{i=1}^{H} x_i \]
  - Computes the mean across the hidden dimension (H)
  - Each sequence position has its own mean
- Variance Calculation (\(\sigma^2\)): \[\Large \sigma^2 = \frac{1}{H} \sum_{i=1}^{H} (x_i - \mu)^2 \]
  - Computes the variance across the hidden dimension
  - Used to scale the normalized values
- Normalization and Scaling: \[\Large \text{LayerNorm}(x) = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta \]
  - First normalizes the input to have zero mean and unit variance
  - Then applies learnable scale (\(\gamma\)) and shift (\(\beta\)) parameters
  - The \(\odot\) symbol represents elementwise multiplication (Hadamard product)
  - For example, if \(\gamma = [1.2, 0.8, 1.5]\) and the normalized input is \([0.5, -0.3, 0.7]\), then \(\gamma \odot x = [0.6, -0.24, 1.05]\)
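These components can also be checked against PyTorch’s built-in layer_norm. A small sketch with illustrative numbers (reusing the \(\gamma\) from the example above); this is a reference check, not the puzzle code:

import torch
import torch.nn.functional as F

x = torch.tensor([[1.0, 2.0, 4.0]])   # one sequence position, hidden dim = 3
gamma = torch.tensor([1.2, 0.8, 1.5])
beta = torch.zeros(3)
eps = 1e-5

mu = x.mean(dim=-1, keepdim=True)                     # per-position mean
var = x.var(dim=-1, unbiased=False, keepdim=True)     # biased variance, as LayerNorm uses
manual = gamma * (x - mu) / torch.sqrt(var + eps) + beta

reference = F.layer_norm(x, normalized_shape=(3,), weight=gamma, bias=beta, eps=eps)
print(torch.allclose(manual, reference, atol=1e-6))   # True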
Why LayerNorm is important
- Training Stability:
  - Prevents activations from growing too large or small
  - Helps maintain consistent signal magnitude throughout the network
- Feature Learning:
  - The scale (\(\gamma\)) and shift (\(\beta\)) parameters allow the network to learn which features are important
  - Can effectively learn to ignore or emphasize certain features
- Independence:
  - Unlike BatchNorm, LayerNorm’s statistics are computed independently for each sample
  - Makes it more suitable for variable-length sequences and small batch sizes
Configuration
- Batch size: BATCH_SIZE = 4
- Sequence length: SEQ_LEN = 4
- Hidden dimension: HIDDEN_DIM = 8
- Output dimension: OUTPUT_DIM = 16
- Epsilon: EPS = 1e-5
- Data type: DType.float32
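For reference, these constants give the following tensor shapes. The Python sketch below is illustrative (the variable names are not the puzzle’s harness); the linear weight is shown in PyTorch’s nn.Linear [output_dim, hidden_dim] layout, which the unfused GPU path transposes before the matmul:

import torch

BATCH_SIZE, SEQ_LEN, HIDDEN_DIM, OUTPUT_DIM, EPS = 4, 4, 8, 16, 1e-5

x = torch.randn(BATCH_SIZE, SEQ_LEN, HIDDEN_DIM, dtype=torch.float32)  # input  [4, 4, 8]
ln_weight = torch.ones(HIDDEN_DIM)                                     # gamma  [8]
ln_bias = torch.zeros(HIDDEN_DIM)                                      # beta   [8]
linear_weight = torch.randn(OUTPUT_DIM, HIDDEN_DIM)                    # W      [16, 8]
linear_bias = torch.randn(OUTPUT_DIM)                                  # b      [16]
# LayerNorm + Linear output: [BATCH_SIZE, SEQ_LEN, OUTPUT_DIM] = [4, 4, 16]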
Implementation approaches
1. Unfused implementation
The unfused approach executes operations separately using multiple kernels. Here are some of the kernels we wrote in the previous chapters:
Matrix multiplication kernel
From Puzzle 14, we reuse the tiled matrix multiplication kernel for the linear transformation. This kernel includes bounds checking to handle variable matrix dimensions safely:
# Idiomatic tiled matmul from p14.mojo - adapted for [batch*seq, hidden] @ [hidden, output] -> [batch*seq, output]
fn matmul_idiomatic_tiled[
a_layout: Layout,
b_layout: Layout,
out_layout: Layout,
rows: Int,
cols: Int,
inner_dim: Int,
](
output: LayoutTensor[mut=True, dtype, out_layout],
a: LayoutTensor[mut=False, dtype, a_layout],
b: LayoutTensor[mut=False, dtype, b_layout],
):
"""Idiomatic tiled matmul following p14.mojo exactly."""
local_row = thread_idx.x
local_col = thread_idx.y
tiled_row = block_idx.y * TPB + local_row
tiled_col = block_idx.x * TPB + local_col
# Get the tile of the output matrix that this thread block is responsible for
out_tile = output.tile[TPB, TPB](block_idx.x, block_idx.y)
a_shared = tb[dtype]().row_major[TPB, TPB]().shared().alloc().fill(0)
b_shared = tb[dtype]().row_major[TPB, TPB]().shared().alloc().fill(0)
var acc: output.element_type = 0
alias load_a_layout = Layout.row_major(1, TPB)
alias load_b_layout = Layout.row_major(TPB, 1)
for idx in range((inner_dim + TPB - 1) // TPB):
# Get tiles from A and B matrices
a_tile = a.tile[TPB, TPB](block_idx.x, idx)
b_tile = b.tile[TPB, TPB](idx, block_idx.y)
# Asynchronously copy tiles to shared memory
copy_dram_to_sram_async[thread_layout=load_a_layout](a_shared, a_tile)
copy_dram_to_sram_async[thread_layout=load_b_layout](b_shared, b_tile)
# Wait for all async copies to complete
async_copy_wait_all()
barrier()
# Compute partial matrix multiplication for this tile
@parameter
for k in range(TPB):
acc += a_shared[local_row, k] * b_shared[k, local_col]
barrier()
# Write final result with bounds checking (needed for variable matrix sizes)
if tiled_row < rows and tiled_col < cols:
out_tile[local_row, local_col] = acc
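For orientation, the linear step this kernel implements corresponds roughly to the following PyTorch reshape-and-matmul (an illustrative sketch only; variable names are not from the puzzle code):

import torch

batch_size, seq_len, hidden_dim, output_dim = 4, 4, 8, 16
normalized = torch.randn(batch_size, seq_len, hidden_dim)  # LayerNorm output
weight_t = torch.randn(hidden_dim, output_dim)             # weight already in [hidden, output] layout

# Flatten [batch, seq, hidden] -> [batch*seq, hidden], multiply, reshape back.
flat = normalized.reshape(batch_size * seq_len, hidden_dim)         # [16, 8]
out = (flat @ weight_t).reshape(batch_size, seq_len, output_dim)    # [4, 4, 16]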
Transpose kernel
For efficient memory access patterns, we use a transpose kernel with shared memory tiling:
fn transpose_kernel[
layout_in: Layout,
layout_out: Layout,
rows: Int,
cols: Int,
](
output: LayoutTensor[mut=True, dtype, layout_out],
input: LayoutTensor[mut=False, dtype, layout_in],
):
"""Transpose matrix using shared memory tiling for coalesced access.
We will learn more about coalesced access in the next part.
"""
shared_tile = tb[dtype]().row_major[TPB, TPB]().shared().alloc()
local_row = thread_idx.y
local_col = thread_idx.x
global_row = block_idx.y * TPB + local_row
global_col = block_idx.x * TPB + local_col
if global_row < rows and global_col < cols:
shared_tile[local_row, local_col] = input[global_row, global_col]
else:
shared_tile[local_row, local_col] = 0.0
barrier()
out_row = block_idx.x * TPB + local_row
out_col = block_idx.y * TPB + local_col
# Store data from shared memory to global memory (coalesced write)
# Note: we transpose the shared memory access pattern
if out_row < cols and out_col < rows:
output[out_row, out_col] = shared_tile[local_col, local_row]
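The transpose is needed because the linear weight is assumed to arrive in the [output_dim, hidden_dim] layout that PyTorch’s nn.Linear uses (as the correctness comparison against PyTorch suggests), while the tiled matmul above consumes a [hidden_dim, output_dim] operand. In PyTorch terms the whole kernel is roughly (a sketch under that layout assumption):

import torch

linear_weight = torch.randn(16, 8)       # [output_dim, hidden_dim], nn.Linear convention
weight_t = linear_weight.T.contiguous()  # [hidden_dim, output_dim], what the matmul kernel consumes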
Bias addition kernel
A simple elementwise addition kernel for adding the bias term:
fn add_bias_kernel[
input_layout: Layout,
bias_layout: Layout,
output_layout: Layout,
batch_size: Int,
seq_len: Int,
output_dim: Int,
](
output: LayoutTensor[mut=True, dtype, output_layout],
input: LayoutTensor[mut=False, dtype, input_layout],
bias: LayoutTensor[mut=False, dtype, bias_layout],
):
"""Simple bias addition."""
batch_idx = block_idx.x
seq_idx = block_idx.y
out_idx = thread_idx.x
if batch_idx >= batch_size or seq_idx >= seq_len or out_idx >= output_dim:
return
output[batch_idx, seq_idx, out_idx] = input[
batch_idx, seq_idx, out_idx
] + rebind[Scalar[dtype]](bias[out_idx])
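In framework terms this kernel is just a broadcast add over the output dimension; a PyTorch sketch with illustrative names:

import torch

matmul_out = torch.randn(4, 4, 16)  # matmul result [batch, seq, output_dim]
bias = torch.randn(16)              # [output_dim]
result = matmul_out + bias          # bias broadcasts across the batch and sequence dims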
LayerNorm kernel
Now complete this kernel to implement the LayerNorm operation. You’ll need to:
- Compute mean \(\mu\) and variance \(\sigma^2\) for each sequence position
- Normalize the input using these statistics
- Apply the scale \(\gamma\) and shift \(\beta\) parameters
fn layernorm_kernel[
input_layout: Layout,
ln_params_layout: Layout,
output_layout: Layout,
batch_size: Int,
seq_len: Int,
hidden_dim: Int,
](
output: LayoutTensor[mut=True, dtype, output_layout],
input: LayoutTensor[mut=False, dtype, input_layout],
ln_weight: LayoutTensor[mut=False, dtype, ln_params_layout],
ln_bias: LayoutTensor[mut=False, dtype, ln_params_layout],
):
batch_idx = block_idx.x
seq_idx = block_idx.y
hidden_idx = thread_idx.x
if (
batch_idx >= batch_size
or seq_idx >= seq_len
or hidden_idx >= hidden_dim
):
return
# Compute statistics for this sequence position (redundant but simple)
var sum_val: Scalar[dtype] = 0
var sq_sum: Scalar[dtype] = 0
# FILL ME IN (roughly 11 lines)
Implementation steps:
- First, compute the mean and variance for this sequence position (a simple per-thread loop over the hidden dimension works here)
- Then normalize the input using these statistics
- Finally, apply the scale and shift parameters
Characteristics of unfused approach:
- Multiple kernel launches (LayerNorm → MatMul → Bias)
- Intermediate tensor allocations between operations
- More memory bandwidth usage due to separate passes
- Simpler implementation with clear separation of concerns
- Easier to debug as each operation is isolated
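The unfused GPU path behaves much like eager PyTorch, where each logical step runs as at least one kernel launch and materializes an intermediate tensor in global memory. A reference sketch (not the puzzle’s harness):

import torch

x = torch.randn(4, 4, 8)
ln = torch.nn.LayerNorm(8)
linear = torch.nn.Linear(8, 16)

normalized = ln(x)                          # LayerNorm pass, writes a [4, 4, 8] intermediate
matmul_out = normalized @ linear.weight.T   # matmul pass, reads the intermediate back from memory
out = matmul_out + linear.bias              # separate bias-add pass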
Tips
- Thread organization:
  - Use one thread block per sequence position (grid: [batch_size, seq_len])
  - Each thread handles one hidden dimension element
  - Avoid redundant computation by computing statistics once per sequence
- Memory access:
  - Access the input tensor with [batch_idx, seq_idx, hidden_idx]
  - Access the output tensor with [batch_idx, seq_idx, hidden_idx]
  - Access the LayerNorm parameters with [hidden_idx]
- Numerical stability:
  - Add epsilon (1e-5) before taking the square root
  - Use rebind[Scalar[dtype]] for proper type casting
  - Compute the variance as (sq_sum / hidden_dim) - (mean * mean)
- Performance:
  - Compute mean and variance in a single pass
  - Reuse the computed statistics for all elements in the sequence
  - Avoid unnecessary memory barriers
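The single-pass variance in the tips is the identity \(\sigma^2 = E[x^2] - \mu^2\). A quick check (sketch below) confirms it matches the two-pass definition; note that this form can lose precision when the mean is large relative to the spread, though at float32 with these inputs the effect is negligible, as the ~1e-8 max differences in the test output suggest.

import torch

x = torch.randn(8)
mean = x.mean()
one_pass = (x * x).mean() - mean * mean   # E[x^2] - mu^2
two_pass = ((x - mean) ** 2).mean()       # biased variance, as LayerNorm uses
print(torch.allclose(one_pass, two_pass, atol=1e-6))  # True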
Running the code
To test your unfused implementation, run:
uv run poe p20 --unfused
pixi run p20 --unfused
Your output will look like this:
Testing with dimensions: [4, 4, 8] -> [4, 4, 16]
✅ Loaded Mojo operations library
============================================================
Puzzle 20: UNFUSED Algorithm Test & Benchmark
============================================================
🧪 Correctness Testing for UNFUSED Algorithm
====================================================
Testing Reference PyTorch Implementation
-----------------------------------------------
✅ Reference PyTorch
Max difference: 0.00e+00
Result: ✅ CORRECT
Testing CPU Implementation
---------------------------------
✅ Using Mojo fused kernel (CPU)
Max difference: 1.86e-08
Result: ✅ CORRECT
Testing GPU Unfused Implementation
-----------------------------------------
✅ Using Mojo unfused kernel (GPU)
Max difference: 1.86e-08
Result: ✅ CORRECT
Correctness Summary:
- Reference: ✅ CORRECT
- CPU: ✅ CORRECT
- GPU unfused: ✅ CORRECT
Overall Correctness: ✅ ALL CORRECT
Benchmarking CPU vs GPU UNFUSED
------------------------------------------
Testing CPU performance...
CPU: 3173.70ms (50 iterations)
Testing GPU unfused performance...
GPU unfused: 3183.57ms (50 iterations)
GPU unfused vs CPU: 1.00x slower
CPU wins (GPU overhead > computation benefit)
UNFUSED Algorithm Test Completed!
Solution
fn layernorm_kernel[
input_layout: Layout,
ln_params_layout: Layout,
output_layout: Layout,
batch_size: Int,
seq_len: Int,
hidden_dim: Int,
](
output: LayoutTensor[mut=True, dtype, output_layout],
input: LayoutTensor[mut=False, dtype, input_layout],
ln_weight: LayoutTensor[mut=False, dtype, ln_params_layout],
ln_bias: LayoutTensor[mut=False, dtype, ln_params_layout],
):
batch_idx = block_idx.x
seq_idx = block_idx.y
hidden_idx = thread_idx.x
if (
batch_idx >= batch_size
or seq_idx >= seq_len
or hidden_idx >= hidden_dim
):
return
# Compute statistics for this sequence position (redundant but simple)
var sum_val: Scalar[dtype] = 0
var sq_sum: Scalar[dtype] = 0
@parameter
for h in range(hidden_dim):
val = input[batch_idx, seq_idx, h]
sum_val += rebind[Scalar[dtype]](val)
sq_sum += rebind[Scalar[dtype]](val * val)
mean_val = sum_val / hidden_dim
var_val = (sq_sum / hidden_dim) - (mean_val * mean_val)
inv_std = 1.0 / sqrt(var_val + 1e-5)
# Apply LayerNorm to this element
input_val = input[batch_idx, seq_idx, hidden_idx]
normalized = (input_val - mean_val) * inv_std * rebind[Scalar[dtype]](
ln_weight[hidden_idx]
) + rebind[Scalar[dtype]](ln_bias[hidden_idx])
output[batch_idx, seq_idx, hidden_idx] = normalized
The unfused implementation follows a straightforward approach where each thread handles one element of the output tensor. Let’s break down the key components:
- Thread and Block Organization:
  batch_idx = block_idx.x
  seq_idx = block_idx.y
  hidden_idx = thread_idx.x
  - Each thread block handles one sequence position in the batch
  - Grid dimensions: [batch_size, seq_len]
  - Each thread processes one element in the hidden dimension
  - Early return if indices are out of bounds:
    if (batch_idx >= batch_size or seq_idx >= seq_len or hidden_idx >= hidden_dim): return
- Statistics Computation:
  var sum_val: Scalar[dtype] = 0
  var sq_sum: Scalar[dtype] = 0

  @parameter
  for h in range(hidden_dim):
      val = input[batch_idx, seq_idx, h]
      sum_val += rebind[Scalar[dtype]](val)
      sq_sum += rebind[Scalar[dtype]](val * val)
  - Compute sum and squared sum in a single pass
  - Use @parameter for compile-time loop unrolling
  - Proper type casting with rebind[Scalar[dtype]]
  - Calculate mean and variance:
    mean_val = sum_val / hidden_dim
    var_val = (sq_sum / hidden_dim) - (mean_val * mean_val)
    inv_std = 1.0 / sqrt(var_val + 1e-5)
- Normalization and Scaling:
  input_val = input[batch_idx, seq_idx, hidden_idx]
  normalized = (input_val - mean_val) * inv_std * rebind[Scalar[dtype]](
      ln_weight[hidden_idx]
  ) + rebind[Scalar[dtype]](ln_bias[hidden_idx])
  output[batch_idx, seq_idx, hidden_idx] = normalized
  - Apply normalization: \[\Large \text{normalized} = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta \]
  - Scale with the learnable parameter γ (ln_weight)
  - Add the learnable bias β (ln_bias)
  - Store the result in the output tensor
- Performance Characteristics:
  - Each thread computes statistics independently
  - No shared memory usage (simple but less efficient)
  - Memory access pattern:
    - Input: [batch_idx, seq_idx, h]
    - Output: [batch_idx, seq_idx, hidden_idx]
    - Parameters: [hidden_idx]
  - Numerical stability ensured by:
    - Adding epsilon (1e-5) before the square root
    - Using proper type casting
    - Computing the variance in a single pass via \(E[x^2] - \mu^2\)
- Implementation Details:
  - Type Safety:
    - Use Scalar[dtype] for intermediate calculations
    - Use rebind[Scalar[dtype]] for proper type casting
    - Ensures consistent floating-point precision
  - Memory Access:
    - Coalesced reads from the input tensor
    - Coalesced writes to the output tensor
    - Sequential access to the LayerNorm parameters
  - Computation Flow:
    - Statistics computation: \[\Large O(H) \text{ operations per thread} \]
    - Normalization: \[\Large O(1) \text{ operations per thread} \]
    - Total complexity: \[\Large O(H) \text{ per output element} \]
  - Limitations:
    - Redundant computation of statistics
    - No shared memory for intermediate results
    - High memory bandwidth usage
    - Multiple kernel launches required
This implementation is correct but not optimal for performance, as shown in the benchmark results, where it’s slightly slower than the CPU version. The fused implementation will address these performance limitations by:
- Computing statistics once per sequence
- Reusing normalized values
- Reducing memory traffic
- Eliminating intermediate tensor allocations
2. Fused kernel implementation
The fused kernel combines LayerNorm and Linear operations into a single GPU kernel:
fn minimal_fused_kernel[
input_layout: Layout,
ln_params_layout: Layout,
weight_layout: Layout,
bias_layout: Layout,
output_layout: Layout,
batch_size: Int,
seq_len: Int,
hidden_dim: Int,
output_dim: Int,
](
output: LayoutTensor[mut=True, dtype, output_layout],
input: LayoutTensor[mut=False, dtype, input_layout],
ln_weight: LayoutTensor[mut=False, dtype, ln_params_layout],
ln_bias: LayoutTensor[mut=False, dtype, ln_params_layout],
linear_weight: LayoutTensor[mut=False, dtype, weight_layout],
linear_bias: LayoutTensor[mut=False, dtype, bias_layout],
):
"""Minimal fused kernel - one thread per sequence position to avoid redundancy.
"""
# Grid: (batch_size, seq_len) - one thread block per sequence position
# Block: (1,) - single thread per sequence position to avoid redundant computation
batch_idx = block_idx.x
seq_idx = block_idx.y
if batch_idx >= batch_size or seq_idx >= seq_len:
return
# Step 1: Compute LayerNorm statistics once per sequence position
# FILL IN roughly 10 lines
# Step 2: Compute all outputs for this sequence position
# FILL IN roughly 10 lines
Key optimizations:
- Single kernel launch instead of several
- Intermediate results kept in registers rather than written to global memory
- Coalesced memory access patterns
- Reduced memory bandwidth usage
- No intermediate tensor allocations
Tips
- Thread organization:
  - One thread block per sequence position (grid: [batch_size, seq_len])
  - Single thread per sequence position to avoid redundancy
  - Compute all outputs for each sequence position in one thread
- Memory access:
  - Access the input tensor with [batch_idx, seq_idx, h]
  - Access the output tensor with [batch_idx, seq_idx, out_idx]
  - Access the linear weights with [out_idx, h]
- Computation flow:
  - Compute the LayerNorm statistics once per sequence position
  - Reuse the normalized values for all output dimensions
  - Combine normalization and the linear transformation
- Performance:
  - Avoid redundant computation of statistics
  - Minimize memory traffic by fusing operations
  - Use proper type casting with rebind[Scalar[dtype]]
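Before writing the Mojo version, the control flow can be prototyped in plain Python. The sketch below mirrors the two steps in the skeleton (statistics once per position, then all output dimensions); it is a reference only, not the kernel code, and the function name is illustrative:

import torch

def fused_reference(x, ln_weight, ln_bias, linear_weight, linear_bias, eps=1e-5):
    batch_size, seq_len, hidden_dim = x.shape
    output_dim = linear_weight.shape[0]
    out = torch.empty(batch_size, seq_len, output_dim)
    for b in range(batch_size):        # grid dimension x
        for s in range(seq_len):       # grid dimension y
            row = x[b, s]
            # Step 1: LayerNorm statistics once per sequence position
            mean = row.mean()
            var = (row * row).mean() - mean * mean
            inv_std = 1.0 / torch.sqrt(var + eps)
            normalized = (row - mean) * inv_std * ln_weight + ln_bias
            # Step 2: all output dims for this position, no global intermediate tensor
            out[b, s] = normalized @ linear_weight.T + linear_bias
    return out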
Running the code
To test your fused implementation, run:
uv run poe p20 --fused
pixi run p20 --fused
Your output will look like this:
Testing with dimensions: [4, 4, 8] -> [4, 4, 16]
✅ Loaded Mojo operations library
============================================================
Puzzle 20: FUSED Algorithm Test & Benchmark
============================================================
🧪 Correctness Testing for FUSED Algorithm
==================================================
Testing Reference PyTorch Implementation
-----------------------------------------------
✅ Reference PyTorch
Max difference: 0.00e+00
Result: ✅ CORRECT
Testing CPU Implementation
---------------------------------
✅ Using Mojo fused kernel (CPU)
Max difference: 1.86e-08
Result: ✅ CORRECT
Testing GPU Fused Implementation
---------------------------------------
✅ Using Mojo fused kernel (GPU)
Max difference: 1.86e-08
Result: ✅ CORRECT
Correctness Summary:
- Reference: ✅ CORRECT
- CPU: ✅ CORRECT
- GPU fused: ✅ CORRECT
Overall Correctness: ✅ ALL CORRECT
⚡ Benchmarking CPU vs GPU FUSED
----------------------------------------
Testing CPU performance...
CPU: 3144.75ms (50 iterations)
Testing GPU fused performance...
GPU fused: 3116.11ms (50 iterations)
GPU fused vs CPU: 1.01x faster
GPU fused wins!
FUSED Algorithm Test Completed!
Solution
fn minimal_fused_kernel[
input_layout: Layout,
ln_params_layout: Layout,
weight_layout: Layout,
bias_layout: Layout,
output_layout: Layout,
batch_size: Int,
seq_len: Int,
hidden_dim: Int,
output_dim: Int,
](
output: LayoutTensor[mut=True, dtype, output_layout],
input: LayoutTensor[mut=False, dtype, input_layout],
ln_weight: LayoutTensor[mut=False, dtype, ln_params_layout],
ln_bias: LayoutTensor[mut=False, dtype, ln_params_layout],
linear_weight: LayoutTensor[mut=False, dtype, weight_layout],
linear_bias: LayoutTensor[mut=False, dtype, bias_layout],
):
"""Minimal fused kernel - one thread per sequence position to avoid redundancy.
"""
# Grid: (batch_size, seq_len) - one thread block per sequence position
# Block: (1,) - single thread per sequence position to avoid redundant computation
batch_idx = block_idx.x
seq_idx = block_idx.y
if batch_idx >= batch_size or seq_idx >= seq_len:
return
# Step 1: Compute LayerNorm statistics once per sequence position
var sum_val: Scalar[dtype] = 0
var sq_sum: Scalar[dtype] = 0
@parameter
for h in range(hidden_dim):
val = input[batch_idx, seq_idx, h]
sum_val += rebind[Scalar[dtype]](val)
sq_sum += rebind[Scalar[dtype]](val * val)
mean_val = sum_val / hidden_dim
var_val = (sq_sum / hidden_dim) - (mean_val * mean_val)
inv_std = 1.0 / sqrt(var_val + 1e-5)
# Step 2: Compute all outputs for this sequence position
@parameter
for out_idx in range(output_dim):
var acc: Scalar[dtype] = 0
@parameter
for h in range(hidden_dim):
input_val = input[batch_idx, seq_idx, h]
normalized = (input_val - mean_val) * inv_std * rebind[
Scalar[dtype]
](ln_weight[h]) + rebind[Scalar[dtype]](ln_bias[h])
acc += rebind[Scalar[dtype]](normalized * linear_weight[out_idx, h])
output[batch_idx, seq_idx, out_idx] = acc + rebind[Scalar[dtype]](
linear_bias[out_idx]
)
The fused implementation combines operations efficiently:
- Thread organization:
  - One thread block per sequence position (grid: [batch_size, seq_len])
  - Single thread per sequence position
  - Thread indices: batch_idx = block_idx.x, seq_idx = block_idx.y
- LayerNorm phase:
  - Compute the sum and squared sum for the sequence position
  - Calculate the mean: \[\Large \mu = \frac{1}{H} \sum_{i=1}^{H} x_i \]
  - Calculate the variance: \[\Large \sigma^2 = \frac{1}{H} \sum_{i=1}^{H} (x_i - \mu)^2 \]
  - Compute the inverse standard deviation: \[\Large \text{inv\_std} = \frac{1}{\sqrt{\sigma^2 + \epsilon}} \]
- Linear phase:
  - For each output dimension:
    - Compute the normalized value: \[\Large \text{normalized} = \gamma \odot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta \]
    - Multiply with the linear weight and accumulate: \[\Large \text{acc} = \sum_{h=1}^{H} \text{normalized}_h \cdot W_{out,h} \]
    - Add the linear bias: \[\Large \text{output} = \text{acc} + b_{out} \]
    - Store the result in output[batch_idx, seq_idx, out_idx]
- Performance optimizations:
  - Single kernel launch for both operations
  - Reuse computed statistics
  - Minimize memory traffic
  - No intermediate tensor allocations
  - Efficient memory access patterns
This implementation achieves better performance than the unfused version by reducing memory bandwidth usage and kernel launch overhead.
Advantages of kernel fusion
In this puzzle, we’ve explored two approaches to implementing LayerNorm + Linear operations:
- Unfused implementation:
  - Separate kernels for LayerNorm and Linear
  - Simpler implementation but less efficient
  - Higher memory bandwidth usage
  - Multiple kernel launches
  - Benchmark results: 3183.57ms (GPU)
- Fused implementation:
  - Single kernel combining both operations
  - More complex but more efficient (the gap widens at larger problem sizes)
  - Reduced memory bandwidth usage
  - Single kernel launch
  - Benchmark results: 3116.11ms (GPU)
Memory bandwidth optimization
- Eliminated memory traffic:
  - No intermediate tensor allocations between operations
  - Reduced global memory reads/writes
  - Reuse of normalized values for the linear transformation
  - Memory bandwidth reduction: \[\Large \text{reduction} = \frac{\text{unfused\_bandwidth} - \text{fused\_bandwidth}}{\text{unfused\_bandwidth}}\]
- Cache efficiency:
  - Better L1/L2 cache utilization
  - Reduced cache misses
  - Improved memory access patterns
  - Higher arithmetic intensity
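As a rough, back-of-the-envelope example of the reduction formula above, consider only the traffic that fusion eliminates for the normalized intermediate at this puzzle’s sizes. This assumes the [BATCH, SEQ, HIDDEN] LayerNorm output is written once and read once in the unfused path, and never touches global memory in the fused path:

BATCH_SIZE, SEQ_LEN, HIDDEN_DIM = 4, 4, 8
BYTES_PER_FLOAT32 = 4

intermediate = BATCH_SIZE * SEQ_LEN * HIDDEN_DIM * BYTES_PER_FLOAT32  # 512 bytes
eliminated = 2 * intermediate                                         # one write + one read = 1024 bytes
print(eliminated)  # tiny at these toy sizes, but it scales with batch * seq * hidden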
Reduced overhead
- Kernel launch optimization:
  - Single kernel launch instead of multiple
  - Lower driver overhead
  - Reduced synchronization points
  - Fewer memory allocations
- Resource management:
  - Shared memory reuse between operations
  - Better register utilization
  - Improved thread occupancy
  - Higher GPU utilization
Performance characteristics
- Scalability:
  - Better performance scaling with input size
  - Reduced memory bandwidth bottleneck
  - More efficient use of GPU resources
  - Improved throughput for large models
- Numerical efficiency:
  - Maintained numerical stability
  - Reduced rounding errors
  - Better precision in intermediate results
  - Optimized computation order
💡 Key insight: Kernel fusion is particularly beneficial for operations that are frequently used together in neural networks, like LayerNorm + Linear in transformer architectures. The performance benefits become more significant with larger input sizes and more complex models.