Embedding Kernels: Coalesced vs. Non-Coalesced

In this puzzle, you’ll implement two different GPU kernels for embedding operations that produce identical results but use different memory access patterns, demonstrating the critical importance of memory coalescing in GPU performance.

1D coalesced kernel (optimized approach)

This kernel uses a simple 1D grid where each thread processes exactly one output element. The key insight is that consecutive threads will access consecutive memory locations, leading to optimal memory coalescing.

Thread organization:

  • Grid configuration: ceildiv(total_elements, 256) blocks, 256 threads per block
  • Thread mapping: Each thread handles one (batch, seq, embed) position
  • Memory pattern: Consecutive threads access consecutive embedding dimensions
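As a quick check on the grid size, here is a minimal Python sketch of the launch arithmetic. `ceildiv` and `grid_1d` are hypothetical helpers written for illustration; they mirror the `max(1, ceildiv(total_elements, THREADS_PER_BLOCK))` computation used when the kernel is launched. Rounding up matters: with plain floor division, a tail of fewer than 256 elements would get no threads at all.

```python
THREADS_PER_BLOCK = 256


def ceildiv(a: int, b: int) -> int:
    """Ceiling division for non-negative integers."""
    return -(-a // b)


def grid_1d(batch_size: int, seq_len: int, embed_dim: int) -> tuple:
    """Return (total_elements, blocks) for the 1D launch (hypothetical helper)."""
    total_elements = batch_size * seq_len * embed_dim
    blocks = max(1, ceildiv(total_elements, THREADS_PER_BLOCK))
    return total_elements, blocks


print(grid_1d(8, 512, 512))  # (2097152, 8192): divides evenly
print(grid_1d(3, 5, 7))      # (105, 1): floor division would give 0 blocks
```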

What you need to implement:

  1. Calculate the global thread index from block and thread indices
  2. Convert the flat index to 3D coordinates (batch_idx, seq_idx, embed_idx)
  3. Look up the token index from the indices tensor
  4. Copy the appropriate embedding vector element to the output
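Taken together, these four steps must reproduce an ordinary table lookup. As a reference point, here is a plain-Python version of the expected result, with nested lists standing in for the tensors. `embedding_reference` is a hypothetical helper, not part of the puzzle; out-of-range token indices yield zero rows, matching the bounds handling described in the tips.

```python
def embedding_reference(indices, weights, vocab_size):
    """indices: [batch][seq] ints; weights: [vocab][embed] floats (hypothetical helper)."""
    embed_dim = len(weights[0])
    out = []
    for row in indices:
        out_row = []
        for tok in row:
            if 0 <= tok < vocab_size:
                out_row.append(list(weights[tok]))  # copy the whole embedding row
            else:
                out_row.append([0.0] * embed_dim)   # invalid token -> zeros
        out.append(out_row)
    return out


weights = [[0.0, 0.1], [1.0, 1.1], [2.0, 2.1]]  # vocab_size=3, embed_dim=2
indices = [[2, 0], [5, -1]]                    # second batch has only invalid tokens
print(embedding_reference(indices, weights, vocab_size=3))
# [[[2.0, 2.1], [0.0, 0.1]], [[0.0, 0.0], [0.0, 0.0]]]
```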

Code to complete

You need to complete the missing parts of the 1D coalesced kernel:

from gpu import block_idx, block_dim, thread_idx
from layout import Layout, LayoutTensor


alias THREADS_PER_BLOCK = 256


fn embedding_kernel_coalesced[
    indices_layout: Layout,
    weights_layout: Layout,
    out_layout: Layout,
    batch_size: Int,
    seq_len: Int,
    vocab_size: Int,
    embed_dim: Int,
    dtype: DType = DType.float32,
](
    output: LayoutTensor[mut=True, dtype, out_layout],
    indices: LayoutTensor[mut=True, DType.int32, indices_layout],
    weights: LayoutTensor[mut=True, dtype, weights_layout],
):
    """
    Memory-coalescing focused embedding kernel.

    Key insight: The bottleneck is memory access patterns, not computation.
    - Each thread handles one (batch, seq, embed) position
    - Simple 1D grid for maximum simplicity and correctness
    - Focus on getting memory access right first
    """

    # Simple 1D indexing - each thread = one output element
    global_idx = block_idx.x * block_dim.x + thread_idx.x
    total_elements = batch_size * seq_len * embed_dim

    if global_idx >= total_elements:
        return

    # Convert to (batch, seq, embed) coordinates
    # FILL IN roughly 4 lines

    # Get token index
    # FILL IN 1 line

    # Simple, correct assignment
    # FILL IN 4 lines


View full file: problems/p19/op/embedding.mojo

Tips
  • Start with global_idx = block_idx.x * block_dim.x + thread_idx.x
  • Convert to 3D coordinates using division and modulo: batch_idx = global_idx // (seq_len * embed_dim)
  • Use remaining = global_idx % (seq_len * embed_dim) to simplify further calculations
  • Always check bounds: if global_idx >= total_elements: return
  • Handle invalid token indices by setting output to 0
  • The embedding lookup is: output[batch_idx, seq_idx, embed_idx] = weights[token_idx, embed_idx]
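You can sanity-check the index arithmetic from these tips in Python before writing the Mojo version. This sketch uses hypothetical helper names. It converts a flat index to (batch_idx, seq_idx, embed_idx), verifies that the conversion inverts row-major flattening so each thread owns a distinct output element, and shows that consecutive flat indices differ only in embed_idx, which is what makes the 1D kernel's accesses contiguous.

```python
def flat_to_3d(global_idx, seq_len, embed_dim):
    """Flat index -> (batch_idx, seq_idx, embed_idx), same steps as the tips."""
    batch_idx = global_idx // (seq_len * embed_dim)
    remaining = global_idx % (seq_len * embed_dim)
    seq_idx = remaining // embed_dim
    embed_idx = remaining % embed_dim
    return batch_idx, seq_idx, embed_idx


def flat_from_3d(batch_idx, seq_idx, embed_idx, seq_len, embed_dim):
    """Row-major offset of output[batch_idx, seq_idx, embed_idx]."""
    return (batch_idx * seq_len + seq_idx) * embed_dim + embed_idx


B, L, E = 2, 3, 4
# Every flat index round-trips, so no two threads write the same element
assert all(flat_from_3d(*flat_to_3d(g, L, E), L, E) == g for g in range(B * L * E))

print(flat_to_3d(13, L, E))                       # (1, 0, 1)
print([flat_to_3d(g, L, E) for g in range(4, 8)])
# [(0, 1, 0), (0, 1, 1), (0, 1, 2), (0, 1, 3)]
```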

2D non-coalesced kernel (comparison approach)

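This kernel uses a 2D grid where the X dimension spans (batch × seq) positions and the Y dimension spans embedding dimensions. Because threads are grouped into warps with thread_idx.x varying fastest, neighboring threads handle neighboring positions rather than neighboring embedding dimensions, which leads to non-coalesced memory access.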

Thread organization:

  • Grid configuration: [ceildiv(batch × seq, 16), ceildiv(embed_dim, 16)] blocks, 16 × 16 threads per block
  • Thread mapping: thread_idx.x maps to batch/sequence, thread_idx.y maps to embedding dimension
  • Memory pattern: Neighboring threads in a warp write output elements embed_dim apart and read from different rows of the weight table
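To see why this mapping scatters accesses, it helps to list which (thread_idx.x, thread_idx.y) each warp lane gets. The sketch below assumes 32-lane warps built from consecutive linear thread IDs with thread_idx.x varying fastest, the usual convention on NVIDIA GPUs; other hardware (for example 64-wide AMD wavefronts) groups threads differently. All names are illustrative.

```python
BLOCK_X, BLOCK_Y = 16, 16
WARP_SIZE = 32  # assumption: 32-lane warps of consecutive linear thread IDs


def lane_to_thread(lane, warp_id=0, block_x=BLOCK_X):
    """(thread_idx.x, thread_idx.y) of a lane, assuming x varies fastest."""
    linear = warp_id * WARP_SIZE + lane
    return linear % block_x, linear // block_x


def warp_output_offsets_2d(embed_dim, warp_id=0):
    """Flat output offsets written by one warp of block (0, 0) in the 2D kernel."""
    offsets = []
    for lane in range(WARP_SIZE):
        tx, ty = lane_to_thread(lane, warp_id)
        batch_seq_idx, embed_idx = tx, ty  # block (0, 0): no block offset
        offsets.append(batch_seq_idx * embed_dim + embed_idx)
    return offsets


offs = warp_output_offsets_2d(embed_dim=512)
print(offs[:4])    # [0, 512, 1024, 1536]
print(offs[16:18]) # [1, 513]
```

Lanes 0 to 15 share thread_idx.y = 0 and write addresses 512 elements apart, so one warp's stores are spread over 16 distant locations instead of one contiguous run.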

What you need to implement:

  1. Calculate both X and Y coordinates from the 2D grid
  2. Convert the X coordinate to separate batch and sequence indices
  3. Use the Y coordinate directly as the embedding dimension
  4. Perform the same embedding lookup with bounds checking

Code to complete

You need to complete the missing parts of the 2D non-coalesced kernel:

fn embedding_kernel_2d[
    indices_layout: Layout,
    weights_layout: Layout,
    out_layout: Layout,
    batch_size: Int,
    seq_len: Int,
    vocab_size: Int,
    embed_dim: Int,
    dtype: DType = DType.float32,
](
    output: LayoutTensor[mut=True, dtype, out_layout],
    indices: LayoutTensor[mut=True, DType.int32, indices_layout],
    weights: LayoutTensor[mut=True, dtype, weights_layout],
):
    """
    2D grid non-coalesced embedding kernel.

    Non-optimal approach for comparison:
    - 2D grid: (batch*seq, embed_dim)
    - More complex indexing
    - Potentially worse memory access patterns
    """

    # 2D grid indexing
    batch_seq_idx = block_idx.x * block_dim.x + thread_idx.x
    embed_idx = block_idx.y * block_dim.y + thread_idx.y
    total_positions = batch_size * seq_len

    if batch_seq_idx >= total_positions or embed_idx >= embed_dim:
        return

    # Convert to (batch, seq) coordinates
    # FILL IN 2 lines

    # Get token index
    # FILL IN 1 line

    # Assignment with 2D grid pattern
    # FILL IN 4 lines


View full file: problems/p19/op/embedding.mojo

Tips
  • Use both X and Y thread coordinates: batch_seq_idx = block_idx.x * block_dim.x + thread_idx.x
  • And: embed_idx = block_idx.y * block_dim.y + thread_idx.y
  • Convert batch_seq_idx to separate batch and sequence indices: batch_idx = batch_seq_idx // seq_len
  • Remember to check bounds for both dimensions: if batch_seq_idx >= total_positions or embed_idx >= embed_dim
  • The token lookup and invalid-token handling are the same as in the 1D kernel; as there, each thread writes exactly one output element

Custom ops registration

The kernels are registered as MAX custom operations and called from PyTorch. The registration pattern is the same as the one explained in Understanding MAX Graph custom ops:

1D coalesced operation

This operation registers the optimized 1D embedding kernel as "embedding":

import compiler
from math import ceildiv
from runtime.asyncrt import DeviceContextPtr
from tensor import InputTensor, OutputTensor
from memory import UnsafePointer
from gpu.host import DeviceBuffer


@compiler.register("embedding")
struct EmbeddingCustomOp:
    @staticmethod
    fn execute[
        target: StaticString,
        batch_size: Int,
        seq_len: Int,
        vocab_size: Int,
        embed_dim: Int,
    ](
        output: OutputTensor[
            dtype = DType.float32, rank=3
        ],  # [batch_size, seq_len, embed_dim]
        indices: InputTensor[
            dtype = DType.int32, rank=2
        ],  # [batch_size, seq_len]
        weights: InputTensor[
            dtype = output.dtype, rank=2
        ],  # [vocab_size, embed_dim]
        ctx: DeviceContextPtr,
    ) raises:
        output_tensor = output.to_layout_tensor()
        indices_tensor = indices.to_layout_tensor()
        weights_tensor = weights.to_layout_tensor()

        alias indices_layout = indices_tensor.layout
        alias weights_layout = weights_tensor.layout
        alias out_layout = output_tensor.layout

        @parameter
        if target == "gpu":
            gpu_ctx = ctx.get_device_context()

            # Zero out output tensor
            gpu_ctx.enqueue_memset(
                DeviceBuffer[output.dtype](
                    gpu_ctx,
                    rebind[UnsafePointer[Scalar[output.dtype]]](
                        output_tensor.ptr
                    ),
                    batch_size * seq_len * embed_dim,
                    owning=False,
                ),
                0,
            )

            # Calculate 1D grid dimensions (matching kernel's flat indexing)
            total_elements = batch_size * seq_len * embed_dim
            blocks = max(1, ceildiv(total_elements, THREADS_PER_BLOCK))

            # Compile and launch optimized kernel
            compiled_kernel = gpu_ctx.compile_function[
                embedding_kernel_coalesced[
                    indices_layout,
                    weights_layout,
                    out_layout,
                    batch_size,
                    seq_len,
                    vocab_size,
                    embed_dim,
                    output.dtype,
                ]
            ]()

            gpu_ctx.enqueue_function(
                compiled_kernel,
                output_tensor,
                indices_tensor,
                weights_tensor,
                grid_dim=(blocks,),
                block_dim=(THREADS_PER_BLOCK,),
            )

        elif target == "cpu":
            for batch in range(batch_size):
                for seq in range(seq_len):
                    token_idx_val = Int(indices_tensor[batch, seq])
                    if token_idx_val >= 0 and token_idx_val < vocab_size:
                        for emb in range(embed_dim):
                            output_tensor[batch, seq, emb] = weights_tensor[
                                token_idx_val, emb
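                            ]
                    else:
                        # Out-of-range token: write zeros, as the GPU kernels do
                        for emb in range(embed_dim):
                            output_tensor[batch, seq, emb] = 0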
        else:
            raise Error("Unsupported target: " + target)


Key aspects of this registration:

  • Simple grid configuration: Uses a straightforward 1D grid with ceildiv(total_elements, THREADS_PER_BLOCK) blocks
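  • Output initialization: A single enqueue_memset call zeroes the output buffer before launch; the kernel also writes every element itself, including zeros for invalid tokens
  • Compile-time parameters: All tensor dimensions are passed as compile-time parameters, so the kernel is specialized for the problem size
  • Device abstraction: Dispatches to the GPU kernel or a CPU fallback depending on target, and raises an error for any other target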

2D non-coalesced operation

This operation registers the comparison 2D embedding kernel as "embedding_2d":

@compiler.register("embedding_2d")
struct Embedding2DCustomOp:
    @staticmethod
    fn execute[
        target: StaticString,
        batch_size: Int,
        seq_len: Int,
        vocab_size: Int,
        embed_dim: Int,
    ](
        output: OutputTensor[
            dtype = DType.float32, rank=3
        ],  # [batch_size, seq_len, embed_dim]
        indices: InputTensor[
            dtype = DType.int32, rank=2
        ],  # [batch_size, seq_len]
        weights: InputTensor[
            dtype = output.dtype, rank=2
        ],  # [vocab_size, embed_dim]
        ctx: DeviceContextPtr,
    ) raises:
        output_tensor = output.to_layout_tensor()
        indices_tensor = indices.to_layout_tensor()
        weights_tensor = weights.to_layout_tensor()

        alias indices_layout = indices_tensor.layout
        alias weights_layout = weights_tensor.layout
        alias out_layout = output_tensor.layout

        @parameter
        if target == "gpu":
            gpu_ctx = ctx.get_device_context()

            # Zero out output tensor
            gpu_ctx.enqueue_memset(
                DeviceBuffer[output.dtype](
                    gpu_ctx,
                    rebind[UnsafePointer[Scalar[output.dtype]]](
                        output_tensor.ptr
                    ),
                    batch_size * seq_len * embed_dim,
                    owning=False,
                ),
                0,
            )

            # Calculate 2D grid dimensions for non-coalesced access
            total_positions = batch_size * seq_len
            alias BLOCK_X = 16  # batch*seq dimension
            alias BLOCK_Y = 16  # embed dimension
            blocks_x = max(1, ceildiv(total_positions, BLOCK_X))
            blocks_y = max(1, ceildiv(embed_dim, BLOCK_Y))

            # Compile and launch 2D kernel
            compiled_kernel = gpu_ctx.compile_function[
                embedding_kernel_2d[
                    indices_layout,
                    weights_layout,
                    out_layout,
                    batch_size,
                    seq_len,
                    vocab_size,
                    embed_dim,
                    output.dtype,
                ]
            ]()

            gpu_ctx.enqueue_function(
                compiled_kernel,
                output_tensor,
                indices_tensor,
                weights_tensor,
                grid_dim=(blocks_x, blocks_y),
                block_dim=(BLOCK_X, BLOCK_Y),
            )

        elif target == "cpu":
            # Same CPU fallback as 1D version
            for batch in range(batch_size):
                for seq in range(seq_len):
                    token_idx_val = Int(indices_tensor[batch, seq])
                    if token_idx_val >= 0 and token_idx_val < vocab_size:
                        for emb in range(embed_dim):
                            output_tensor[batch, seq, emb] = weights_tensor[
                                token_idx_val, emb
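                            ]
                    else:
                        # Out-of-range token: write zeros, as the GPU kernels do
                        for emb in range(embed_dim):
                            output_tensor[batch, seq, emb] = 0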
        else:
            raise Error("Unsupported target: " + target)


Key differences from the 1D operation:

  • Complex grid configuration: Uses a 2D grid with separate calculations for blocks_x and blocks_y
  • Fixed block dimensions: Hard-coded BLOCK_X = 16 and BLOCK_Y = 16 for 2D thread organization
  • Same memory management: Identical memory initialization and CPU fallback logic
  • Different kernel call: Passes 2D grid dimensions (blocks_x, blocks_y) and block dimensions (BLOCK_X, BLOCK_Y)
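One way to see that the two operations differ only in how work is mapped to threads is to compare their launch sizes for the sample configuration (B=8, L=512, E=512). This small Python sketch repeats the same ceildiv arithmetic as the registration code; the variable names are illustrative.

```python
def ceildiv(a, b):
    return -(-a // b)


B, L, E = 8, 512, 512  # configuration from the sample run

# 1D launch: ceildiv(total_elements, 256) blocks of 256 threads
blocks_1d = max(1, ceildiv(B * L * E, 256))
threads_1d = blocks_1d * 256

# 2D launch: (ceildiv(B*L, 16), ceildiv(E, 16)) blocks of 16 x 16 threads
blocks_x = max(1, ceildiv(B * L, 16))
blocks_y = max(1, ceildiv(E, 16))
threads_2d = blocks_x * blocks_y * 16 * 16

print(blocks_1d, (blocks_x, blocks_y))  # 8192 (256, 32)
print(threads_1d, threads_2d)           # 2097152 2097152
```

Both launches create exactly 2,097,152 threads, one per output element, so any performance gap comes from the access pattern rather than from the amount of parallelism.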

Common wrapper functionality

Both custom operations provide essential infrastructure:

  1. Memory management:

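    • Zero-initialization of the output tensor with enqueue_memset
    • Conversion of InputTensor/OutputTensor arguments to LayoutTensor with to_layout_tensor(), with their layouts captured as compile-time aliases
    • A non-owning DeviceBuffer (owning=False) that wraps the existing output memory for the memset without taking ownership of it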
  2. Device abstraction:

    • GPU execution with optimized kernels
    • CPU fallback for compatibility and debugging
    • Consistent interface regardless of execution target
  3. Parameter passing:

    • Compile-time tensor dimensions for kernel optimization
    • Runtime tensor data through layout tensor conversion
    • Typed signatures: dtype and rank are declared on each InputTensor/OutputTensor, and weights must match output.dtype
  4. Grid configuration:

    • Grid dimensions derived from the problem size with ceildiv, so every output element is covered
    • A 1D grid of 256-thread blocks for the coalesced kernel and a 2D grid of 16 × 16 blocks for the comparison kernel

Integration with PyTorch

These registered operations can be called from Python using the CustomOpLibrary:

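from max.torch import CustomOpLibrary

# Load the custom operations. mojo_kernels is the path to the directory
# containing embedding.mojo; B, L, V, E are the problem dimensions.
ops = CustomOpLibrary(mojo_kernels)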

# Call the 1D coalesced version
result_1d = ops.embedding[{"batch_size": B, "seq_len": L, "vocab_size": V, "embed_dim": E}](
    indices, weights
)

# Call the 2D non-coalesced version
result_2d = ops.embedding_2d[{"batch_size": B, "seq_len": L, "vocab_size": V, "embed_dim": E}](
    indices, weights
)

Because the kernels are registered as MAX custom ops, the same Mojo implementations can be called from PyTorch, as here, or used inside a MAX graph without rewriting them.

Run the code

You can run the puzzle with either uv or pixi:

uv run poe p19
pixi run p19

When successful, you should see output similar to:

Puzzle 19: Mojo Embedding Kernel Comparison
======================================================================
Configuration: B=8, L=512, V=10000, E=512
------------------------------------------------------------

Testing Correctness...
   1D Coalesced - Max difference: 1.19e-07
   2D Non-coalesced - Max difference: 1.19e-07
   ✅ Both implementations CORRECT

Benchmarking Mojo Kernels...

Performance Results:
   1D Coalesced:     2.145 ms
   2D Non-coalesced: 3.867 ms
   1D is 1.80x faster than 2D

Key Learning Points:
• Compare different GPU kernel implementations
• 1D vs 2D grid patterns have different memory access
• Coalesced memory access should be faster
• Grid configuration affects GPU utilization
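To put the sample timings in context, you can estimate how many bytes the embedding nominally moves and the resulting effective bandwidth. This is a back-of-the-envelope sketch: it assumes float32 weights and output, int32 indices, and counts each weight read, output write, and index read once, ignoring caching. The names are illustrative.

```python
B, L, V, E = 8, 512, 10000, 512  # configuration from the sample run
F32_BYTES = I32_BYTES = 4

# Nominal traffic: one weight read and one output write per output element,
# plus one read per token index. V does not appear: only looked-up rows are read.
weight_bytes = B * L * E * F32_BYTES
output_bytes = B * L * E * F32_BYTES
index_bytes = B * L * I32_BYTES
nominal_bytes = weight_bytes + output_bytes + index_bytes


def effective_gb_per_s(num_bytes, ms):
    """Bytes moved divided by elapsed time, in GB/s (1 GB = 1e9 bytes)."""
    return num_bytes / (ms * 1e-3) / 1e9


t_1d, t_2d = 2.145, 3.867  # ms, sample timings from the output above
print(nominal_bytes)                                      # 16793600
print(round(t_2d / t_1d, 2))                              # 1.8
print(round(effective_gb_per_s(nominal_bytes, t_1d), 2))  # 7.83
print(round(effective_gb_per_s(nominal_bytes, t_2d), 2))  # 4.34
```

Comparing such figures with your GPU's rated memory bandwidth gives a rough sense of how far each kernel is from being bandwidth-limited; keep in mind that measured times may include launch and synchronization overhead.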

Solution

The solution involves implementing the coordinate transformations and memory operations for both kernels:

1D Coalesced Kernel

fn embedding_kernel_coalesced[
    indices_layout: Layout,
    weights_layout: Layout,
    out_layout: Layout,
    batch_size: Int,
    seq_len: Int,
    vocab_size: Int,
    embed_dim: Int,
    dtype: DType = DType.float32,
](
    output: LayoutTensor[mut=True, dtype, out_layout],
    indices: LayoutTensor[mut=True, DType.int32, indices_layout],
    weights: LayoutTensor[mut=True, dtype, weights_layout],
):
    """
    Memory-coalescing focused embedding kernel.

    Key insight: The bottleneck is memory access patterns, not computation.
    - Each thread handles one (batch, seq, embed) position
    - Simple 1D grid for maximum simplicity and correctness
    - Focus on getting memory access right first
    """

    # Simple 1D indexing - each thread = one output element
    global_idx = block_idx.x * block_dim.x + thread_idx.x
    total_elements = batch_size * seq_len * embed_dim

    if global_idx >= total_elements:
        return

    # Convert to (batch, seq, embed) coordinates
    batch_idx = global_idx // (seq_len * embed_dim)
    remaining = global_idx % (seq_len * embed_dim)
    seq_idx = remaining // embed_dim
    embed_idx = remaining % embed_dim

    # Get token index
    token_idx_val = Int(indices[batch_idx, seq_idx])

    # Simple, correct assignment
    if token_idx_val >= 0 and token_idx_val < vocab_size:
        output[batch_idx, seq_idx, embed_idx] = weights[
            token_idx_val, embed_idx
        ]
    else:
        output[batch_idx, seq_idx, embed_idx] = 0


2D Non-Coalesced Kernel

fn embedding_kernel_2d[
    indices_layout: Layout,
    weights_layout: Layout,
    out_layout: Layout,
    batch_size: Int,
    seq_len: Int,
    vocab_size: Int,
    embed_dim: Int,
    dtype: DType = DType.float32,
](
    output: LayoutTensor[mut=True, dtype, out_layout],
    indices: LayoutTensor[mut=True, DType.int32, indices_layout],
    weights: LayoutTensor[mut=True, dtype, weights_layout],
):
    """
    2D grid non-coalesced embedding kernel.

    Non-optimal approach for comparison:
    - 2D grid: (batch*seq, embed_dim)
    - More complex indexing
    - Potentially worse memory access patterns
    """

    # 2D grid indexing
    batch_seq_idx = block_idx.x * block_dim.x + thread_idx.x
    embed_idx = block_idx.y * block_dim.y + thread_idx.y

    total_positions = batch_size * seq_len

    # Bounds check
    if batch_seq_idx >= total_positions or embed_idx >= embed_dim:
        return

    # Convert to (batch, seq) coordinates
    batch_idx = batch_seq_idx // seq_len
    seq_idx = batch_seq_idx % seq_len

    # Get token index
    token_idx_val = Int(indices[batch_idx, seq_idx])

    # Assignment with 2D grid pattern
    if token_idx_val >= 0 and token_idx_val < vocab_size:
        output[batch_idx, seq_idx, embed_idx] = weights[
            token_idx_val, embed_idx
        ]
    else:
        output[batch_idx, seq_idx, embed_idx] = 0
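To convince yourself that both solutions compute the same thing, you can emulate them sequentially in Python: loop over every block and thread, apply the same index math and bounds checks, and compare the flat outputs. This is a CPU emulation sketch with hypothetical names, not a model of GPU timing; the small, deliberately odd sizes exercise the partial blocks at the edges of both grids, and `None` marks any element no thread wrote.

```python
def ceildiv(a, b):
    return -(-a // b)


def lookup(indices, weights, vocab_size, b, s, e):
    # Same validity check as both kernels: invalid token -> 0.0
    tok = indices[b][s]
    return weights[tok][e] if 0 <= tok < vocab_size else 0.0


def emulate_1d(indices, weights, batch_size, seq_len, vocab_size, embed_dim, tpb=256):
    total = batch_size * seq_len * embed_dim
    out = [None] * total  # flat row-major [batch, seq, embed]
    for block in range(max(1, ceildiv(total, tpb))):
        for tid in range(tpb):
            g = block * tpb + tid
            if g >= total:
                continue  # the kernel's early return
            b = g // (seq_len * embed_dim)
            rem = g % (seq_len * embed_dim)
            s, e = rem // embed_dim, rem % embed_dim
            out[(b * seq_len + s) * embed_dim + e] = lookup(indices, weights, vocab_size, b, s, e)
    return out


def emulate_2d(indices, weights, batch_size, seq_len, vocab_size, embed_dim, bdx=16, bdy=16):
    total_positions = batch_size * seq_len
    out = [None] * (total_positions * embed_dim)
    for bx in range(max(1, ceildiv(total_positions, bdx))):
        for by in range(max(1, ceildiv(embed_dim, bdy))):
            for ty in range(bdy):
                for tx in range(bdx):
                    bs = bx * bdx + tx
                    e = by * bdy + ty
                    if bs >= total_positions or e >= embed_dim:
                        continue  # the kernel's early return
                    b, s = bs // seq_len, bs % seq_len
                    out[(b * seq_len + s) * embed_dim + e] = lookup(indices, weights, vocab_size, b, s, e)
    return out


V, E = 5, 20
weights = [[float(t * 100 + e) for e in range(E)] for t in range(V)]
indices = [[0, 4, 7], [2, -1, 3]]  # batch=2, seq=3; 7 and -1 are out of range
r1 = emulate_1d(indices, weights, 2, 3, V, E)
r2 = emulate_2d(indices, weights, 2, 3, V, E)
print(r1 == r2, None in r1)  # True False
```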


Both solutions implement the same embedding lookup logic but with different thread organizations:

Key differences

  1. Thread mapping:

    • 1D kernel: One thread per output element, simple flat indexing
    • 2D kernel: 2D grid mapping to (batch×seq, embed_dim) coordinates
  2. Memory access patterns:

    • 1D kernel: Consecutive threads access consecutive embedding dimensions → coalesced
    • 2D kernel: Adjacent threads in a warp handle adjacent batch×seq positions, so their writes land embed_dim elements apart → non-coalesced
  3. Indexing complexity:

    • 1D kernel: Single division/modulo chain to get 3D coordinates
    • 2D kernel: Separate X/Y coordinate calculations

Performance implications

The 1D kernel typically performs better because:

  • Memory coalescing: Consecutive threads access consecutive memory addresses
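  • Lightweight indexing: The 1D kernel actually does slightly more division/modulo work per thread than the 2D kernel, but with compile-time dimensions this cost is small next to memory access time, so it does not offset the coalescing gain
  • Better cache utilization: When embed_dim is a multiple of the warp size, each warp's loads fall within a single embedding row, so fetched cache lines are fully used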

The 2D kernel may perform worse due to:

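  • Scattered memory accesses: The 16 lanes that share a thread_idx.y value handle 16 different positions, so they read from up to 16 different embedding rows and write output elements embed_dim apart
  • Axis mapping: thread_idx.x, the fastest-varying index within a warp, walks batch×seq positions instead of contiguous embedding dimensions; which axis X is assigned to matters more than the 16×16 block shape itself
  • Warp divergence: A warp spans up to 16 positions with different tokens, so the valid/invalid token branch can split a warp, whereas in the 1D kernel a warp usually shares one token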
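A rough way to quantify the difference is to count how many distinct memory segments one warp touches. This sketch assumes 32-lane warps with thread_idx.x varying fastest, float32 data, and 128-byte segments; real GPUs split traffic into smaller sectors and cache it, so treat the numbers as a simplified model. It looks at the first warp of each kernel for E = 512, and all names are hypothetical.

```python
WARP_SIZE = 32       # assumed warp width, thread_idx.x varying fastest
ELEM_BYTES = 4       # float32
SEGMENT_BYTES = 128  # assumed memory transaction size


def segments_touched(element_offsets):
    """Distinct SEGMENT_BYTES-sized segments covered by a list of element offsets."""
    return len({(off * ELEM_BYTES) // SEGMENT_BYTES for off in element_offsets})


def first_warp_1d(embed_dim, token):
    # Lane i handles flat element i = (batch 0, seq 0, embed i)
    stores = [lane for lane in range(WARP_SIZE)]                     # output[0, 0, i]
    loads = [token * embed_dim + lane for lane in range(WARP_SIZE)]  # weights[token, i]
    return stores, loads


def first_warp_2d(embed_dim, tokens, block_x=16):
    # Lane -> (tx, ty): tx is the batch*seq position, ty the embedding dimension
    stores, loads = [], []
    for lane in range(WARP_SIZE):
        tx, ty = lane % block_x, lane // block_x
        stores.append(tx * embed_dim + ty)         # output[position tx, dim ty]
        loads.append(tokens[tx] * embed_dim + ty)  # weights[tokens[tx], ty]
    return stores, loads


E = 512
s1, l1 = first_warp_1d(E, token=42)
tokens = [7 * p for p in range(16)]  # 16 distinct token ids, one per position
s2, l2 = first_warp_2d(E, tokens)
print(segments_touched(s1), segments_touched(l1))  # 1 1
print(segments_touched(s2), segments_touched(l2))  # 16 16
```

The measured gap in the sample run (1.80x) is far smaller than this 16x ratio, since the model ignores caching and every other cost; it only shows the direction of the effect.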

Key concepts

| Concept | 1D coalesced | 2D non-coalesced |
|---|---|---|
| Thread organization | 1D flat indexing | 2D grid (batch×seq, embed) |
| Memory access | Consecutive addresses | Strided by embed_dim within a warp |
| Grid configuration | Simple: ceildiv(total_elements, 256) blocks | Complex: [ceildiv(batch×seq, 16), ceildiv(embed_dim, 16)] blocks |
| Performance | Optimized for memory bandwidth | Suboptimal memory pattern |
| Use case | Production kernels | Educational comparison |

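The core lesson: for memory-bound operations like embeddings, memory coalescing alone can make a large performance difference (1.80x in the sample run above), even when both kernels launch the same number of threads and do identical work.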