Embedding Kernels: Coalesced vs. Non-Coalesced
In this puzzle, you’ll implement two different GPU kernels for embedding operations that produce identical results but use different memory access patterns, demonstrating the critical importance of memory coalescing in GPU performance.
1D coalesced kernel (optimized approach)
This kernel uses a simple 1D grid where each thread processes exactly one output element. The key insight is that consecutive threads will access consecutive memory locations, leading to optimal memory coalescing.
Thread organization:
- Grid configuration: [total_elements // 256] blocks, 256 threads per block
- Thread mapping: Each thread handles one (batch, seq, embed) position
- Memory pattern: Consecutive threads access consecutive embedding dimensions
What you need to implement:
- Calculate the global thread index from block and thread indices
- Convert the flat index to 3D coordinates (batch_idx, seq_idx, embed_idx) (see the index-math sketch after this list)
- Look up the token index from the indices tensor
- Copy the appropriate embedding vector element to the output
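Before touching the kernel, it can help to check the flat-index arithmetic on the host. The following Python sketch is purely illustrative (the shapes are made up for the example; the kernel gets its dimensions as compile-time parameters):
# Host-side check of the flat-index -> (batch, seq, embed) conversion (illustrative).
batch_size, seq_len, embed_dim = 2, 4, 8

def unflatten(global_idx):
    batch_idx = global_idx // (seq_len * embed_dim)
    remaining = global_idx % (seq_len * embed_dim)
    seq_idx = remaining // embed_dim
    embed_idx = remaining % embed_dim
    return batch_idx, seq_idx, embed_idx

print(unflatten(27))                          # (0, 3, 3), since 27 = 0*32 + 3*8 + 3
print([unflatten(i) for i in range(8, 12)])   # (0,1,0) (0,1,1) (0,1,2) (0,1,3)
# Consecutive flat indices differ only in embed_idx, which is exactly why
# consecutive threads end up touching consecutive weights[token, :] elements.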
Code to complete
You need to complete the missing parts in both embedding kernels:
alias THREADS_PER_BLOCK = 256
fn embedding_kernel_coalesced[
indices_layout: Layout,
weights_layout: Layout,
out_layout: Layout,
batch_size: Int,
seq_len: Int,
vocab_size: Int,
embed_dim: Int,
dtype: DType = DType.float32,
](
output: LayoutTensor[mut=True, dtype, out_layout],
indices: LayoutTensor[mut=True, DType.int32, indices_layout],
weights: LayoutTensor[mut=True, dtype, weights_layout],
):
"""
Memory-coalescing focused embedding kernel.
Key insight: The bottleneck is memory access patterns, not computation.
- Each thread handles one (batch, seq, embed) position
- Simple 1D grid for maximum simplicity and correctness
- Focus on getting memory access right first
"""
# Simple 1D indexing - each thread = one output element
global_idx = block_idx.x * block_dim.x + thread_idx.x
total_elements = batch_size * seq_len * embed_dim
if global_idx >= total_elements:
return
# Convert to (batch, seq, embed) coordinates
# FILL IN roughly 4 lines
# Get token index
# FILL IN 1 line
# Simple, correct assignment
# FILL IN 4 lines
View full file: problems/p19/op/embedding.mojo
Tips
- Start with global_idx = block_idx.x * block_dim.x + thread_idx.x
- Convert to 3D coordinates using division and modulo: batch_idx = global_idx // (seq_len * embed_dim)
- Use remaining = global_idx % (seq_len * embed_dim) to simplify the remaining calculations
- Always check bounds: if global_idx >= total_elements: return
- Handle invalid token indices by setting the output to 0
- The embedding lookup is output[batch_idx, seq_idx, embed_idx] = weights[token_idx, embed_idx] (a host-side reference sketch follows below)
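If you want something to compare against while debugging, a minimal NumPy reference of the same lookup, including the zero-for-invalid-token rule, might look like the sketch below. It mirrors the CPU fallback shown later and is only illustrative:
# Minimal NumPy reference for the embedding lookup (illustrative).
import numpy as np

def embedding_reference(indices, weights):
    # indices: (batch_size, seq_len) int32, weights: (vocab_size, embed_dim) float32
    batch_size, seq_len = indices.shape
    vocab_size, embed_dim = weights.shape
    out = np.zeros((batch_size, seq_len, embed_dim), dtype=weights.dtype)
    for b in range(batch_size):
        for s in range(seq_len):
            token = int(indices[b, s])
            if 0 <= token < vocab_size:   # invalid tokens leave the row at zero
                out[b, s, :] = weights[token, :]
    return out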
2D non-coalesced kernel (comparison approach)
This kernel uses a 2D grid where the X dimension spans (batch × seq)
positions and the Y dimension spans embedding dimensions. This can lead to non-coalesced memory access patterns.
Thread organization:
- Grid configuration: [batch x seq // 16, embed_dim // 16] blocks, 16 x 16 threads per block
- Thread mapping: thread_idx.x maps to the batch/sequence position, thread_idx.y maps to the embedding dimension
- Memory pattern: Threads in a warp may access scattered memory locations (see the access-pattern sketch after the next list)
What you need to implement:
- Calculate both X and Y coordinates from the 2D grid
- Convert the X coordinate to separate batch and sequence indices
- Use the Y coordinate directly as the embedding dimension
- Perform the same embedding lookup with bounds checking
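To see why the two mappings behave differently, you can enumerate the output offsets that 32 consecutive threads touch under each scheme. This is a rough host-side model in Python (it assumes a row-major output layout and that threads within a block are numbered x-fastest), not GPU code:
# Rough model of the offsets written by 32 consecutive threads (illustrative).
seq_len, embed_dim = 512, 512

def offset(b, s, e):
    # flat offset into a row-major (batch, seq, embed) output
    return (b * seq_len + s) * embed_dim + e

# 1D kernel: thread t handles flat element t -> contiguous offsets, coalesced.
coalesced = [offset(0, t // embed_dim, t % embed_dim) for t in range(32)]
print(coalesced[:4], "stride", coalesced[1] - coalesced[0])   # [0, 1, 2, 3] stride 1

# 2D kernel with 16x16 blocks: the first 32 threads cover x = 0..15 for y = 0 and y = 1.
# Along x the (batch, seq) position changes, so neighbors land a full embed_dim apart.
scattered = [offset(x // seq_len, x % seq_len, y) for y in range(2) for x in range(16)]
print(scattered[:4], "stride", scattered[1] - scattered[0])   # [0, 512, 1024, 1536] stride 512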
Code to complete
You need to complete the missing parts in both embedding kernels:
fn embedding_kernel_2d[
indices_layout: Layout,
weights_layout: Layout,
out_layout: Layout,
batch_size: Int,
seq_len: Int,
vocab_size: Int,
embed_dim: Int,
dtype: DType = DType.float32,
](
output: LayoutTensor[mut=True, dtype, out_layout],
indices: LayoutTensor[mut=True, DType.int32, indices_layout],
weights: LayoutTensor[mut=True, dtype, weights_layout],
):
"""
2D grid non-coalesced embedding kernel.
Non-optimal approach for comparison:
- 2D grid: (batch*seq, embed_dim)
- More complex indexing
- Potentially worse memory access patterns
"""
# 2D grid indexing
batch_seq_idx = block_idx.x * block_dim.x + thread_idx.x
embed_idx = block_idx.y * block_dim.y + thread_idx.y
total_positions = batch_size * seq_len
if batch_seq_idx >= total_positions or embed_idx >= embed_dim:
return
# Convert to (batch, seq) coordinates
# FILL IN 2 lines
# Get token index
# FILL IN 1 line
# Assignment with 2D grid pattern
# FILL IN 4 lines
View full file: problems/p19/op/embedding.mojo
Tips
- Use both X and Y thread coordinates: batch_seq_idx = block_idx.x * block_dim.x + thread_idx.x and embed_idx = block_idx.y * block_dim.y + thread_idx.y
- Convert batch_seq_idx to separate batch and sequence indices: batch_idx = batch_seq_idx // seq_len (a quick numeric check follows below)
- Remember to check bounds for both dimensions: if batch_seq_idx >= total_positions or embed_idx >= embed_dim
- The token lookup is the same as in the 1D kernel, but each thread handles only one embedding dimension rather than an entire vector
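The X-coordinate split is just one division and one modulo; a quick host-side check with arbitrary example values:
# Host-side check of the 2D kernel's X-coordinate split (illustrative values).
seq_len = 512
batch_seq_idx = 1030                     # some thread's X coordinate
batch_idx = batch_seq_idx // seq_len     # 2
seq_idx = batch_seq_idx % seq_len        # 6  (1030 - 2*512)
print(batch_idx, seq_idx)                # 2 6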
Custom ops registration
The kernels are wrapped in PyTorch custom operations for easy integration. The registration pattern is the same as MAX custom ops explained in Understanding MAX Graph custom ops:
1D coalesced operation
This operation registers the optimized 1D embedding kernel as "embedding":
import compiler
from runtime.asyncrt import DeviceContextPtr
from tensor import InputTensor, OutputTensor
from memory import UnsafePointer
from gpu.host import DeviceBuffer
@compiler.register("embedding")
struct EmbeddingCustomOp:
@staticmethod
fn execute[
target: StaticString,
batch_size: Int,
seq_len: Int,
vocab_size: Int,
embed_dim: Int,
](
output: OutputTensor[
dtype = DType.float32, rank=3
], # [batch_size, seq_len, embed_dim]
indices: InputTensor[
dtype = DType.int32, rank=2
], # [batch_size, seq_len]
weights: InputTensor[
dtype = output.dtype, rank=2
], # [vocab_size, embed_dim]
ctx: DeviceContextPtr,
) raises:
output_tensor = output.to_layout_tensor()
indices_tensor = indices.to_layout_tensor()
weights_tensor = weights.to_layout_tensor()
alias indices_layout = indices_tensor.layout
alias weights_layout = weights_tensor.layout
alias out_layout = output_tensor.layout
@parameter
if target == "gpu":
gpu_ctx = ctx.get_device_context()
# Zero out output tensor
gpu_ctx.enqueue_memset(
DeviceBuffer[output.dtype](
gpu_ctx,
rebind[UnsafePointer[Scalar[output.dtype]]](
output_tensor.ptr
),
batch_size * seq_len * embed_dim,
owning=False,
),
0,
)
# Calculate 1D grid dimensions (matching kernel's flat indexing)
total_elements = batch_size * seq_len * embed_dim
blocks = max(1, ceildiv(total_elements, THREADS_PER_BLOCK))
# Compile and launch optimized kernel
compiled_kernel = gpu_ctx.compile_function[
embedding_kernel_coalesced[
indices_layout,
weights_layout,
out_layout,
batch_size,
seq_len,
vocab_size,
embed_dim,
output.dtype,
]
]()
gpu_ctx.enqueue_function(
compiled_kernel,
output_tensor,
indices_tensor,
weights_tensor,
grid_dim=(blocks,),
block_dim=(THREADS_PER_BLOCK,),
)
elif target == "cpu":
for batch in range(batch_size):
for seq in range(seq_len):
token_idx_val = Int(indices_tensor[batch, seq])
if token_idx_val >= 0 and token_idx_val < vocab_size:
for emb in range(embed_dim):
output_tensor[batch, seq, emb] = weights_tensor[
token_idx_val, emb
]
else:
raise Error("Unsupported target: " + target)
Key aspects of this registration:
- Simple grid configuration: Uses a straightforward 1D grid with ceildiv(total_elements, THREADS_PER_BLOCK) blocks (the arithmetic is sketched below)
- Memory optimization: A single enqueue_memset call zeroes the output buffer efficiently
- Compile-time parameters: All tensor dimensions are passed as compile-time parameters for optimal performance
- Device abstraction: Handles both GPU execution and CPU fallback seamlessly
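For the configuration used in the sample output below (B=8, L=512, E=512), the launch arithmetic works out as follows; this is just the same ceildiv mirrored in Python for illustration:
# 1D launch configuration mirrored on the host (illustrative).
THREADS_PER_BLOCK = 256
batch_size, seq_len, embed_dim = 8, 512, 512

def ceildiv(a, b):
    return -(-a // b)

total_elements = batch_size * seq_len * embed_dim            # 2_097_152
blocks = max(1, ceildiv(total_elements, THREADS_PER_BLOCK))  # 8192
print(blocks)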
2D non-coalesced operation
This operation registers the comparison 2D embedding kernel as "embedding_2d":
@compiler.register("embedding_2d")
struct Embedding2DCustomOp:
@staticmethod
fn execute[
target: StaticString,
batch_size: Int,
seq_len: Int,
vocab_size: Int,
embed_dim: Int,
](
output: OutputTensor[
dtype = DType.float32, rank=3
], # [batch_size, seq_len, embed_dim]
indices: InputTensor[
dtype = DType.int32, rank=2
], # [batch_size, seq_len]
weights: InputTensor[
dtype = output.dtype, rank=2
], # [vocab_size, embed_dim]
ctx: DeviceContextPtr,
) raises:
output_tensor = output.to_layout_tensor()
indices_tensor = indices.to_layout_tensor()
weights_tensor = weights.to_layout_tensor()
alias indices_layout = indices_tensor.layout
alias weights_layout = weights_tensor.layout
alias out_layout = output_tensor.layout
@parameter
if target == "gpu":
gpu_ctx = ctx.get_device_context()
# Zero out output tensor
gpu_ctx.enqueue_memset(
DeviceBuffer[output.dtype](
gpu_ctx,
rebind[UnsafePointer[Scalar[output.dtype]]](
output_tensor.ptr
),
batch_size * seq_len * embed_dim,
owning=False,
),
0,
)
# Calculate 2D grid dimensions for non-coalesced access
total_positions = batch_size * seq_len
alias BLOCK_X = 16 # batch*seq dimension
alias BLOCK_Y = 16 # embed dimension
blocks_x = max(1, ceildiv(total_positions, BLOCK_X))
blocks_y = max(1, ceildiv(embed_dim, BLOCK_Y))
# Compile and launch 2D kernel
compiled_kernel = gpu_ctx.compile_function[
embedding_kernel_2d[
indices_layout,
weights_layout,
out_layout,
batch_size,
seq_len,
vocab_size,
embed_dim,
output.dtype,
]
]()
gpu_ctx.enqueue_function(
compiled_kernel,
output_tensor,
indices_tensor,
weights_tensor,
grid_dim=(blocks_x, blocks_y),
block_dim=(BLOCK_X, BLOCK_Y),
)
elif target == "cpu":
# Same CPU fallback as 1D version
for batch in range(batch_size):
for seq in range(seq_len):
token_idx_val = Int(indices_tensor[batch, seq])
if token_idx_val >= 0 and token_idx_val < vocab_size:
for emb in range(embed_dim):
output_tensor[batch, seq, emb] = weights_tensor[
token_idx_val, emb
]
else:
raise Error("Unsupported target: " + target)
Key differences from the 1D operation:
- Complex grid configuration: Uses a 2D grid with separate calculations for blocks_x and blocks_y (sketched below)
- Fixed block dimensions: Hard-coded BLOCK_X = 16 and BLOCK_Y = 16 for 2D thread organization
- Same memory management: Identical memory initialization and CPU fallback logic
- Different kernel call: Passes 2D grid dimensions (blocks_x, blocks_y) and block dimensions (BLOCK_X, BLOCK_Y)
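For the same example configuration (B=8, L=512, E=512), the 2D launch arithmetic, again mirrored in Python purely for illustration:
# 2D launch configuration mirrored on the host (illustrative).
BLOCK_X, BLOCK_Y = 16, 16
batch_size, seq_len, embed_dim = 8, 512, 512

def ceildiv(a, b):
    return -(-a // b)

total_positions = batch_size * seq_len                   # 4096
blocks_x = max(1, ceildiv(total_positions, BLOCK_X))     # 256
blocks_y = max(1, ceildiv(embed_dim, BLOCK_Y))           # 32
print(blocks_x, blocks_y)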
Common wrapper functionality
Both custom operations provide essential infrastructure:
- Memory management:
  - Zero-initialization of output tensors with enqueue_memset
  - Proper buffer creation and memory layout handling
  - Automatic cleanup and resource management
- Device abstraction:
  - GPU execution with optimized kernels
  - CPU fallback for compatibility and debugging
  - Consistent interface regardless of execution target
- Parameter passing:
  - Compile-time tensor dimensions for kernel optimization
  - Runtime tensor data through layout tensor conversion
  - Type-safe parameter validation
- Grid configuration:
  - Automatic calculation of optimal grid dimensions
  - Different strategies optimized for each kernel's access pattern
  - Proper block dimension management
Integration with PyTorch
These registered operations can be called from Python using the CustomOpLibrary:
# Load the custom operations
ops = CustomOpLibrary(mojo_kernels)
# Call the 1D coalesced version
result_1d = ops.embedding[{"batch_size": B, "seq_len": L, "vocab_size": V, "embed_dim": E}](
indices, weights
)
# Call the 2D non-coalesced version
result_2d = ops.embedding_2d[{"batch_size": B, "seq_len": L, "vocab_size": V, "embed_dim": E}](
indices, weights
)
The power of this approach is that the same kernel implementations can be used across different Python frameworks while maintaining optimal performance characteristics.
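A quick way to validate both ops from Python is to compare them against PyTorch's built-in embedding. This is a hedged sketch: it assumes result_1d and result_2d were produced by the calls above and that indices and weights are GPU tensors of the expected shapes with valid token ids (F.embedding wants integer indices, hence the .long() cast):
# Illustrative correctness check against torch.nn.functional.embedding.
import torch

expected = torch.nn.functional.embedding(indices.long(), weights)
print(torch.max(torch.abs(result_1d - expected)).item())  # expect ~1e-7 or smaller
print(torch.allclose(result_2d, expected, atol=1e-6))     # expect True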
Run the code
You can run the puzzle with either of the following commands:
uv run poe p19
pixi run p19
When successful, you should see output similar to:
Puzzle 19: Mojo Embedding Kernel Comparison
======================================================================
Configuration: B=8, L=512, V=10000, E=512
------------------------------------------------------------
Testing Correctness...
1D Coalesced - Max difference: 1.19e-07
2D Non-coalesced - Max difference: 1.19e-07
✅ Both implementations CORRECT
Benchmarking Mojo Kernels...
Performance Results:
1D Coalesced: 2.145 ms
2D Non-coalesced: 3.867 ms
1D is 1.80x faster than 2D
Key Learning Points:
• Compare different GPU kernel implementations
• 1D vs 2D grid patterns have different memory access
• Coalesced memory access should be faster
• Grid configuration affects GPU utilization
Solution
The solution involves implementing the coordinate transformations and memory operations for both kernels:
1D Coalesced Kernel
fn embedding_kernel_coalesced[
indices_layout: Layout,
weights_layout: Layout,
out_layout: Layout,
batch_size: Int,
seq_len: Int,
vocab_size: Int,
embed_dim: Int,
dtype: DType = DType.float32,
](
output: LayoutTensor[mut=True, dtype, out_layout],
indices: LayoutTensor[mut=True, DType.int32, indices_layout],
weights: LayoutTensor[mut=True, dtype, weights_layout],
):
"""
Memory-coalescing focused embedding kernel.
Key insight: The bottleneck is memory access patterns, not computation.
- Each thread handles one (batch, seq, embed) position
- Simple 1D grid for maximum simplicity and correctness
- Focus on getting memory access right first
"""
# Simple 1D indexing - each thread = one output element
global_idx = block_idx.x * block_dim.x + thread_idx.x
total_elements = batch_size * seq_len * embed_dim
if global_idx >= total_elements:
return
# Convert to (batch, seq, embed) coordinates
batch_idx = global_idx // (seq_len * embed_dim)
remaining = global_idx % (seq_len * embed_dim)
seq_idx = remaining // embed_dim
embed_idx = remaining % embed_dim
# Get token index
token_idx_val = Int(indices[batch_idx, seq_idx])
# Simple, correct assignment
if token_idx_val >= 0 and token_idx_val < vocab_size:
output[batch_idx, seq_idx, embed_idx] = weights[
token_idx_val, embed_idx
]
else:
output[batch_idx, seq_idx, embed_idx] = 0
2D Non-Coalesced Kernel
fn embedding_kernel_2d[
indices_layout: Layout,
weights_layout: Layout,
out_layout: Layout,
batch_size: Int,
seq_len: Int,
vocab_size: Int,
embed_dim: Int,
dtype: DType = DType.float32,
](
output: LayoutTensor[mut=True, dtype, out_layout],
indices: LayoutTensor[mut=True, DType.int32, indices_layout],
weights: LayoutTensor[mut=True, dtype, weights_layout],
):
"""
2D grid non-coalesced embedding kernel.
Non-optimal approach for comparison:
- 2D grid: (batch*seq, embed_dim)
- More complex indexing
- Potentially worse memory access patterns
"""
# 2D grid indexing
batch_seq_idx = block_idx.x * block_dim.x + thread_idx.x
embed_idx = block_idx.y * block_dim.y + thread_idx.y
total_positions = batch_size * seq_len
# Bounds check
if batch_seq_idx >= total_positions or embed_idx >= embed_dim:
return
# Convert to (batch, seq) coordinates
batch_idx = batch_seq_idx // seq_len
seq_idx = batch_seq_idx % seq_len
# Get token index
token_idx_val = Int(indices[batch_idx, seq_idx])
# Assignment with 2D grid pattern
if token_idx_val >= 0 and token_idx_val < vocab_size:
output[batch_idx, seq_idx, embed_idx] = weights[
token_idx_val, embed_idx
]
else:
output[batch_idx, seq_idx, embed_idx] = 0
Both solutions implement the same embedding lookup logic but with different thread organizations:
Key differences
- Thread mapping:
  - 1D kernel: One thread per output element, simple flat indexing
  - 2D kernel: 2D grid mapping to (batch×seq, embed_dim) coordinates
- Memory access patterns:
  - 1D kernel: Consecutive threads access consecutive embedding dimensions → coalesced
  - 2D kernel: Thread access pattern depends on block configuration → potentially non-coalesced
- Indexing complexity:
  - 1D kernel: Single division/modulo chain to get 3D coordinates
  - 2D kernel: Separate X/Y coordinate calculations
Performance implications
The 1D kernel typically performs better because:
- Memory coalescing: Consecutive threads access consecutive memory addresses
- Simple indexing: Lower computational overhead for coordinate calculations
- Better cache utilization: Predictable memory access patterns
The 2D kernel may perform worse due to:
- Scattered memory accesses: Threads within a warp may access different embedding vectors
- Complex grid configuration: 16×16 blocks may not align optimally with memory layout
- Warp divergence: Different threads may follow different execution paths
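To put the example timings above in perspective, here is a rough back-of-envelope bandwidth estimate (a simplification that ignores caches and the comparatively tiny index reads; the timings are taken from the sample output and will differ on other GPUs):
# Rough effective-bandwidth estimate from the example timings (illustrative).
batch_size, seq_len, embed_dim = 8, 512, 512
bytes_per_float = 4

# Each output element is written once and one weights element is read for it.
bytes_moved = 2 * batch_size * seq_len * embed_dim * bytes_per_float   # ~16.8 MB

for name, ms in [("1D coalesced", 2.145), ("2D non-coalesced", 3.867)]:
    print(f"{name}: ~{bytes_moved / (ms * 1e-3) / 1e9:.1f} GB/s effective")
# The absolute numbers are hardware-dependent and well below peak bandwidth;
# the relevant takeaway is the relative gap between the two kernels.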
Key concepts
Concept | 1D Coalesced | 2D Non-coalesced |
---|---|---|
Thread organization | 1D flat indexing | 2D grid (batch×seq, embed) |
Memory access | Consecutive addresses | Potentially scattered |
Grid configuration | Simple: [total_elements // 256] | Complex: [batch×seq // 16, embed // 16] |
Performance | Optimized for memory bandwidth | Suboptimal memory pattern |
Use case | Production kernels | Educational comparison |
The core lesson: memory coalescing can lead to 2-3x performance differences for memory-bound operations like embeddings.