Puzzle 17: Attention Op
Overview
In this puzzle, we’ll implement the attention mechanism as a custom MAX Graph operation. Attention is a fundamental building block of modern neural networks, popularized in particular by transformers, that allows models to focus on relevant parts of the input when making predictions.
Mathematically, the attention function is defined as:
$$\Large \text{Attention}(Q, K, V) = \text{softmax}(Q \cdot K^T) \cdot V$$
Where:
- \(Q\) is the query vector of shape \((d,)\) - represents what we’re looking for
- \(K\) is the key matrix of shape \((\text{seq_len}, d)\) - represents what’s available to match against
- \(V\) is the value matrix of shape \((\text{seq_len}, d)\) - represents the information to retrieve
- The output is a weighted combination vector of shape \((d,)\)
The computation involves three main steps:
- Attention Scores: Compute \(Q \cdot K^T\) to measure how well the query matches each key vector
- Attention Weights: Apply softmax to convert scores into a probability distribution (weights sum to 1)
- Weighted Sum: Combine value vectors using attention weights to produce the final output
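As a concrete reference, here is a minimal NumPy sketch of these three steps for the vector-query case used in this puzzle. The function name attention_reference and the random test inputs are illustrative only; they are not part of the puzzle code.
import numpy as np

def attention_reference(q, K, V):
    # Step 1: attention scores - one dot product per key vector, shape (seq_len,)
    scores = K @ q
    # Step 2: numerically stable softmax turns scores into weights that sum to 1
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    # Step 3: weighted combination of the value vectors, shape (d,)
    return weights @ V

seq_len, d = 16, 16
rng = np.random.default_rng(0)
q = rng.normal(size=d)
K = rng.normal(size=(seq_len, d))
V = rng.normal(size=(seq_len, d))
print(attention_reference(q, K, V).shape)  # (16,)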
Understanding attention: a step-by-step breakdown
Think of attention as a smart lookup mechanism. Given a query (what you’re looking for), attention finds the most relevant information from a collection of key-value pairs:
- Step 1 - Similarity Matching: Compare your query \(Q\) against all keys \(K\) to get similarity scores
  - Compute \(Q \cdot K^T\), where each score measures how well \(Q\) matches each key vector
  - Higher scores = better matches
- Step 2 - Probability Distribution: Convert raw scores into normalized weights
  - Apply softmax to ensure all weights sum to 1.0
  - This creates a probability distribution over which values to focus on
- Step 3 - Weighted Retrieval: Combine values using the attention weights
  - Multiply each value vector by its corresponding weight
  - Sum everything up to get the final output
Real-world analogy: Imagine searching a library. Your query is what you want to find, the book titles are keys, and the book contents are values. Attention computes how relevant each book is to your query, then gives you a summary weighted by relevance.
Visual computation flow
Input: Q(16,) K(16,16) V(16,16)
↓ ↓ ↓
Step 1: Q(1,16) @ K^T(16,16) → Scores(1,16)
↓
Step 2: softmax(Scores) → Weights(1,16) [sum = 1.0]
↓
Step 3: Weights(1,16) @ V(16,16) → Output(1,16) → reshape → Output(16,)
Key insight: We reshape the query vector \(Q\) from shape \((16,)\) to \((1,16)\) so we can use matrix multiplication instead of manual dot products. This allows us to leverage the highly optimized tiled matmul kernel from Puzzle 14!
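To see why the reshape works, here is a small NumPy sketch (illustrative only, not the puzzle code): a single \((1,16) \times (16,16)\) matrix multiplication produces exactly the same scores as 16 separate dot products.
import numpy as np

seq_len, d = 16, 16
rng = np.random.default_rng(0)
q = rng.normal(size=d)
K = rng.normal(size=(seq_len, d))

# Manual dot products: one score per key row
scores_manual = np.array([q @ K[i] for i in range(seq_len)])

# Reshape trick: (1, d) @ (d, seq_len) -> (1, seq_len), a single matmul
scores_matmul = (q.reshape(1, d) @ K.T).reshape(seq_len)

assert np.allclose(scores_manual, scores_matmul)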
Our GPU implementation reuses and combines optimized kernels from previous puzzles:
- Tiled matrix multiplication from Puzzle 14 for efficient \(Q \cdot K^T\) and \(\text{weights} \cdot V\) operations
- Shared memory transpose for computing \(K^T\) efficiently
- Parallel softmax from Puzzle 16 for numerically stable attention weight computation
🔄 Kernel Reuse Strategy: This puzzle demonstrates how to build complex operations by combining proven, optimized kernels from previous puzzles. Rather than writing everything from scratch, we leverage matmul_idiomatic_tiled from Puzzle 14 and softmax_kernel from Puzzle 16, showcasing the power of modular GPU kernel design.
Key concepts
- Vector attention mechanism for sequence processing
- Kernel reuse: Leveraging proven implementations from Puzzle 14 and Puzzle 16
- Efficient matrix multiplication using shared memory tiling
- Memory-optimized tensor reshaping to minimize buffer allocation
- Integration of multiple optimized kernels into a single operation
- Custom MAX Graph operation with multi-input support
- CPU fallback implementation for compatibility
Configuration
- Sequence length: \(\text{SEQ_LEN} = 16\) - number of key/value vectors in our sequence
- Model dimension: \(\text{D} = 16\) - dimensionality of each vector (query, keys, values)
- Threads per block: \(\text{TPB} = 16\) - matches SEQ_LEN for optimal softmax performance
- Grid dimensions: Computed dynamically to handle different matrix sizes efficiently
- Shared memory: Utilized in transpose, matmul, and softmax kernels for performance
Layout configuration:
- Query tensor: Layout.row_major(d)
- Key tensor: Layout.row_major(seq_len, d)
- Value tensor: Layout.row_major(seq_len, d)
- Output tensor: Layout.row_major(d)
- Custom op parameters: {"seq_len": seq_len, "d": d, "dtype": dtype}
Key aspects of this puzzle include:
- Multi-kernel orchestration: Combining transpose, matmul, and softmax operations
- Memory optimization: Using reshape operations and buffer reuse to minimize allocations
- Numerical stability: Leveraging the proven softmax implementation from Puzzle 16
- Performance optimization: Using tiled algorithms from Puzzle 14 for all matrix operations
- Multi-input operations: Handling three input tensors (Q, K, V) in a single custom op
Our attention custom operation will:
- Accept query, key, and value tensors from Python
- Process them efficiently on GPU using optimized kernels
- Return the attention-weighted output vector
- Match the results of the NumPy reference implementation
Code to complete
To complete this puzzle, we’ll leverage the tiled matmul kernel from Puzzle 14 and the softmax kernel from Puzzle 16. You only need to implement the transpose kernel in the Mojo file using shared memory.
1. Implement the transpose kernel
fn transpose_kernel[
    layout_in: Layout,  # Layout for input matrix (seq_len, d)
    layout_out: Layout,  # Layout for output matrix (d, seq_len)
    rows: Int,
    cols: Int,
    dtype: DType = DType.float32,
](
    output: LayoutTensor[mut=True, dtype, layout_out, MutableAnyOrigin],
    inp: LayoutTensor[mut=False, dtype, layout_in, MutableAnyOrigin],
):
    # FILL ME IN (roughly 18 lines)
    ...
View full file: problems/p17/op/attention.mojo
Tips
Transpose Kernel Implementation Guide:
- Shared Memory Setup: Use tb[dtype]().row_major[TPB, TPB]().shared().alloc() to create a TPB×TPB shared memory tile for efficient data exchange between threads
- Thread Indexing: Map threads to matrix elements: local_row = thread_idx.y, local_col = thread_idx.x (position within the block); global_row = block_idx.y * TPB + local_row, global_col = block_idx.x * TPB + local_col (position in the full matrix)
- Two-Phase Operation (see the NumPy sketch after these tips):
  - Phase 1: Load data from global memory into shared memory with normal indexing
  - Phase 2: Store data from shared memory to global memory with swapped indexing
- Critical Synchronization: Call barrier() between loading and storing to ensure all threads have finished loading before any thread starts storing
- Transpose Magic: The transpose happens through swapped indexing: shared_tile[local_col, local_row] instead of shared_tile[local_row, local_col]
- Boundary Handling: Check bounds when accessing global memory to avoid out-of-bounds reads/writes for matrices that don’t divide evenly by TPB
- Memory Coalescing: This pattern keeps both reads and writes coalesced for optimal memory bandwidth utilization
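The following NumPy sketch models this index arithmetic on the CPU; it is illustrative only (plain Python loops standing in for thread blocks and shared memory), not the Mojo kernel you need to write. Note how phase 2 swaps both the block coordinates and the tile indices.
import numpy as np

TPB = 4
rows, cols = 6, 10  # deliberately not multiples of TPB, to exercise the bounds checks
inp = np.arange(rows * cols, dtype=np.float32).reshape(rows, cols)
out = np.zeros((cols, rows), dtype=np.float32)

# One iteration of the two outer loops corresponds to one thread block.
for block_y in range((rows + TPB - 1) // TPB):
    for block_x in range((cols + TPB - 1) // TPB):
        tile = np.zeros((TPB, TPB), dtype=np.float32)  # stands in for the shared memory tile
        # Phase 1: load with normal indexing (coalesced read on the GPU)
        for local_row in range(TPB):
            for local_col in range(TPB):
                g_row = block_y * TPB + local_row
                g_col = block_x * TPB + local_col
                if g_row < rows and g_col < cols:
                    tile[local_row, local_col] = inp[g_row, g_col]
        # barrier() would go here on the GPU
        # Phase 2: store with swapped block coordinates and swapped tile indexing
        for local_row in range(TPB):
            for local_col in range(TPB):
                o_row = block_x * TPB + local_row
                o_col = block_y * TPB + local_col
                if o_row < cols and o_col < rows:
                    out[o_row, o_col] = tile[local_col, local_row]

assert np.array_equal(out, inp.T)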
2. Orchestrate the attention
var gpu_ctx = rebind[DeviceContext](ctx[])

# Define layouts for matrix multiplication
# Q reshaped to (1, d)
alias layout_q_2d = Layout.row_major(1, d)
# K^T is (d, seq_len)
alias layout_k_t = Layout.row_major(d, seq_len)
# Scores as (1, seq_len)
alias layout_scores_2d = Layout.row_major(1, seq_len)
# Weights as (1, seq_len)
alias layout_weights_2d = Layout.row_major(1, seq_len)
# Result as (1, d)
alias layout_result_2d = Layout.row_major(1, d)

alias scores_blocks_per_grid = (
    (seq_len + TPB - 1) // TPB,
    (1 + TPB - 1) // TPB,
)
alias result_blocks_per_grid = (
    (d + TPB - 1) // TPB,
    (1 + TPB - 1) // TPB,
)
alias matmul_threads_per_block = (TPB, TPB)
alias transpose_blocks_per_grid = (
    (seq_len + TPB - 1) // TPB,
    (d + TPB - 1) // TPB,
)

# Allocate minimal temporary buffers - reuse same buffer for different shapes
k_t_buf = gpu_ctx.enqueue_create_buffer[dtype](
    seq_len * d
)  # K^T as (d, seq_len)
scores_weights_buf = gpu_ctx.enqueue_create_buffer[dtype](
    seq_len
)  # Reused for scores and weights

k_t = LayoutTensor[mut=True, dtype, layout_k_t, MutableAnyOrigin](
    k_t_buf.unsafe_ptr()
)

# Step 1: Reshape Q from (d,) to (1, d) - no buffer needed
# FILL ME IN 1 line

# Step 2: Transpose K from (seq_len, d) to K^T (d, seq_len)
# FILL ME IN 1 function call

# Step 3: Compute attention scores using matmul: Q @ K^T = (1, d) @ (d, seq_len) -> (1, seq_len)
# GPU: Uses matrix multiplication to compute all Q · K[i] scores in parallel
# Reuse scores_weights_buf as (1, seq_len) for scores
# FILL ME IN 2 lines

# Step 4: Reshape scores from (1, seq_len) to (seq_len,) for softmax
# FILL ME IN 1 line

# Step 5: Apply softmax to get attention weights
# FILL ME IN 1 function call

# Step 6: Reshape weights from (seq_len,) to (1, seq_len) for final matmul
# FILL ME IN 1 line

# Step 7: Compute final result using matmul: weights @ V = (1, seq_len) @ (seq_len, d) -> (1, d)
# Reuse out_tensor reshaped as (1, d) for result
# FILL ME IN 2 lines
View full file: problems/p17/op/attention.mojo
Test the kernels
uv run poe p17
pixi run p17
When successful, you should see output similar to the following on both CPU and GPU:
Input shapes: Q=(16,), K=(16, 16), V=(16, 16)
Sample Q values: [ 0.04967142 -0.01382643 0.06476886 0.15230298 -0.02341534]
Sample K[0] values: [-0.10128311 0.03142473 -0.09080241 -0.14123037 0.14656489]
Sample V[0] values: [ 0.11631638 0.00102331 -0.09815087 0.04621035 0.01990597]
================================================================================
STEP-BY-STEP VECTOR ATTENTION COMPUTATION DEBUG
================================================================================
1. INPUT SHAPES:
Q shape: (16,) (query vector)
K shape: (16, 16) (key matrix)
V shape: (16, 16) (value matrix)
Q[:5]: [ 0.04967142 -0.01382643 0.06476886 0.15230298 -0.02341534]
2. ATTENTION SCORES (K[i] · Q):
Scores shape: (16,)
Scores[:5]: [-0.03479404 -0.01563787 0.04834607 0.06764711 0.04001468]
Min: -0.061636, Max: 0.067647
Manual verification:
Q · K[0] = K[0] · Q = -0.034794 (computed: -0.034794)
Q · K[1] = K[1] · Q = -0.015638 (computed: -0.015638)
Q · K[2] = K[2] · Q = 0.048346 (computed: 0.048346)
3. SOFTMAX:
Max score: 0.067647
Attention weights shape: (16,)
Attention weights[:5]: [0.05981331 0.06097015 0.06499878 0.0662655 0.06445949]
Sum: 1.000000 (should be 1.0)
4. WEIGHTED SUM OF VALUES:
Output shape: (16,)
Output[:5]: [-0.00935538 -0.0243433 0.00306551 0.02346884 0.019306 ]
Output norm: 0.092764
Manual output[:5]: [-0.00935538 -0.0243433 0.00306551 0.02346884 0.019306 ]
Match: True
================================================================================
TESTING INDIVIDUAL OPERATIONS
================================================================================
Test 1: Vector Dot Product
a · b = 3.000000
Test 2: Matrix-Vector Multiplication
M @ v = [ 3. 7. 11.]
Test 3: Softmax
Input: [1. 2. 3. 4.]
Softmax: [0.0320586 0.08714432 0.2368828 0.6439143 ]
Sum: 1.000000
================================================================================
TESTING FULL ATTENTION
================================================================================
Compiling attention graph on Device(type=cpu,id=0)
Executing attention on Device(type=cpu,id=0)
====================================================================================================
CPU attention output[:5]: [-0.00935538 -0.02434331 0.00306551 0.02346884 0.019306 ]
CPU matches NumPy: True
Compiling attention graph on Device(type=gpu,id=0)
Executing attention on Device(type=gpu,id=0)
====================================================================================================
GPU attention output[:5]: [-0.00935538 -0.0243433 0.00306551 0.02346884 0.019306 ]
Expected output[:5]: [-0.00935538 -0.0243433 0.00306551 0.02346884 0.019306 ]
GPU matches NumPy: True
================================================================================
FINAL VERIFICATION
================================================================================
✓ CPU implementation PASSED
✓ GPU implementation PASSED
Output vector norms:
CPU: 0.092764
GPU: 0.092764
Expected: 0.092764
This indicates that your custom MAX Graph operation correctly implements the attention algorithm and produces results matching the NumPy reference implementation.
Solution
To solve this puzzle, we need to implement the transpose kernel in Mojo and complete the Python graph definition for our attention custom operation. This puzzle builds upon concepts from previous puzzles, combining tiled matrix multiplication from Puzzle 14 and softmax from Puzzle 16 into a complete attention mechanism.
Reused kernels
Our implementation directly incorporates these proven kernels:
- matmul_idiomatic_tiled from Puzzle 14 - Powers both \(Q \times K^T\) and \(\text{weights} \times V\) operations
- softmax_kernel from Puzzle 16 - Provides numerically stable attention weight computation
This exemplifies modular GPU architecture: complex neural network operations built by orchestrating proven, optimized components rather than monolithic implementations.
The attention operation follows the canonical mathematical definition:
$$\Large \text{Attention}(Q, K, V) = \text{softmax}(Q \cdot K^T) \cdot V$$
Breaking down the math:
- \(Q \cdot K^T\): Query-key similarity scores of shape: \((1, \text{seq_len})\)
- \(\text{softmax}(\cdot)\): Normalize scores to probabilities of shape: \((1, \text{seq_len})\)
- \(\text{weights} \cdot V\): Weighted combination of values of shape: \((1, d)\)
This involves several computational steps that we optimize using GPU kernels from previous puzzles.
1. Transpose kernel implementation:
fn transpose_kernel[
    layout_in: Layout,  # Layout for input matrix (seq_len, d)
    layout_out: Layout,  # Layout for output matrix (d, seq_len)
    rows: Int,
    cols: Int,
    dtype: DType = DType.float32,
](
    output: LayoutTensor[mut=True, dtype, layout_out, MutableAnyOrigin],
    inp: LayoutTensor[mut=False, dtype, layout_in, MutableAnyOrigin],
):
    """Transpose matrix using shared memory tiling for coalesced access."""
    shared_tile = tb[dtype]().row_major[TPB, TPB]().shared().alloc()

    local_row = thread_idx.y
    local_col = thread_idx.x

    global_row = block_idx.y * TPB + local_row
    global_col = block_idx.x * TPB + local_col

    # Load data from global memory into shared memory (coalesced read)
    if global_row < rows and global_col < cols:
        shared_tile[local_row, local_col] = inp[global_row, global_col]
    else:
        shared_tile[local_row, local_col] = 0.0

    barrier()

    # Output coordinates: swap the block indices so this block writes the transposed tile
    out_row = block_idx.x * TPB + local_row
    out_col = block_idx.y * TPB + local_col

    # Store data from shared memory to global memory (coalesced write)
    # Note: we transpose the shared memory access pattern
    if out_row < cols and out_col < rows:
        output[out_row, out_col] = shared_tile[local_col, local_row]
The transpose kernel uses shared memory tiling to achieve coalesced memory access patterns. Key implementation details:
Critical transpose pattern
# Load with normal indexing
shared_tile[local_row, local_col] = inp[global_row, global_col]
barrier()
# Store with swapped indexing for transpose
output[out_row, out_col] = shared_tile[local_col, local_row]
The transpose happens through swapped indexing in the shared memory access (shared_tile[local_col, local_row] instead of shared_tile[local_row, local_col]) and swapped block coordinates for output positioning. This ensures both reads and writes remain coalesced while achieving the transpose operation.
2. GPU kernel orchestration:
# Step 1: Reshape Q from (d,) to (1, d) - no buffer needed
q_2d = q_tensor.reshape[layout_q_2d]()

# Step 2: Transpose K from (seq_len, d) to K^T (d, seq_len)
gpu_ctx.enqueue_function[
    transpose_kernel[layout_k, layout_k_t, seq_len, d, dtype]
](
    k_t,
    k_tensor,
    grid_dim=transpose_blocks_per_grid,
    block_dim=matmul_threads_per_block,
)

# Step 3: Compute attention scores using matmul: Q @ K^T = (1, d) @ (d, seq_len) -> (1, seq_len)
# This computes Q · K^T[i] = Q · K[i] for each column i of K^T (which is row i of K)
# Reuse scores_weights_buf as (1, seq_len) for scores
scores_2d = LayoutTensor[
    mut=True, dtype, layout_scores_2d, MutableAnyOrigin
](scores_weights_buf.unsafe_ptr())
gpu_ctx.enqueue_function[
    matmul_idiomatic_tiled[layout_q_2d, 1, seq_len, d, dtype]
](
    scores_2d,
    q_2d,
    k_t,
    grid_dim=scores_blocks_per_grid,
    block_dim=matmul_threads_per_block,
)

# Step 4: Reshape scores from (1, seq_len) to (seq_len,) for softmax
weights = scores_2d.reshape[layout_scores]()

# Step 5: Apply softmax to get attention weights
gpu_ctx.enqueue_function[
    softmax_kernel[layout_scores, seq_len, dtype]
](
    weights,
    weights,
    grid_dim=(1, 1),
    block_dim=(seq_len, 1),
)

# Step 6: Reshape weights from (seq_len,) to (1, seq_len) for final matmul
weights_2d = weights.reshape[layout_weights_2d]()

# Step 7: Compute final result using matmul: weights @ V = (1, seq_len) @ (seq_len, d) -> (1, d)
# Reuse out_tensor reshaped as (1, d) for result
result_2d = output_tensor.reshape[layout_result_2d]()
gpu_ctx.enqueue_function[
    matmul_idiomatic_tiled[layout_weights_2d, 1, d, seq_len, dtype]
](
    result_2d,
    weights_2d,
    v_tensor,
    grid_dim=result_blocks_per_grid,
    block_dim=matmul_threads_per_block,
)
The GPU orchestration demonstrates sophisticated kernel chaining and zero-copy memory optimization:
Advanced memory optimization strategies
# Zero-copy reshaping - no data movement, just reinterpret tensor shape
q_2d = q_tensor.reshape[layout_q_2d]()
# Aggressive buffer reuse - same memory, different interpretations
weights = scores_2d.reshape[layout_scores]()
The implementation achieves maximum memory efficiency through:
- Zero-copy reshaping: Reinterpreting tensor shapes without moving data in memory
- Intelligent buffer reuse: The same scores_weights_buf serves dual purposes as both scores \((1,\text{seq\_len})\) and weights \((\text{seq\_len},)\)
- Minimal allocations: Only 2 temporary buffers power the entire attention operation
- Memory coalescing: All operations maintain optimal memory access patterns
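As a CPU-side analogy, the NumPy sketch below (illustrative only; the real code uses LayoutTensor reshape on a device buffer) shows the aliasing behavior that the buffer reuse relies on: reshaping reinterprets the same memory rather than copying it.
import numpy as np

seq_len = 16
buf = np.zeros(seq_len, dtype=np.float32)  # one allocation, analogous to scores_weights_buf
scores_2d = buf.reshape(1, seq_len)        # view as (1, seq_len): no copy
weights_1d = scores_2d.reshape(seq_len)    # view as (seq_len,): still the same memory

scores_2d[0, 3] = 7.0
assert weights_1d[3] == 7.0                # all three names alias one buffer
assert np.shares_memory(buf, weights_1d)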
Strategic kernel reuse pattern
- Steps 3 & 7: Both leverage matmul_idiomatic_tiled from Puzzle 14
  - Step 3: \(Q \times K^T\) → attention scores computation \((1,d) \times (d,\text{seq_len}) \rightarrow (1,\text{seq_len})\)
  - Step 7: \(\text{weights} \times V\) → final weighted output \((1,\text{seq_len}) \times (\text{seq_len},d) \rightarrow (1,d)\)
- Step 5: Employs softmax_kernel from Puzzle 16
  - Converts raw scores into a normalized probability distribution
  - Ensures numerical stability through max subtraction and parallel reduction (see the sketch below)
  - Guarantees \(\sum_{i} \text{weights}[i] = 1.0\)
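To make the max-subtraction point concrete, here is a minimal NumPy sketch of numerically stable softmax; the helper name softmax_stable is illustrative, and this models the math only, not the parallel reduction the GPU kernel performs.
import numpy as np

def softmax_stable(scores: np.ndarray) -> np.ndarray:
    # Subtracting the max keeps exp() in a safe range; the shift cancels out
    # in the normalization, so the result is mathematically unchanged.
    exp_shifted = np.exp(scores - scores.max())
    return exp_shifted / exp_shifted.sum()

weights = softmax_stable(np.array([1000.0, 1001.0, 1002.0]))
print(weights, weights.sum())  # finite weights that sum to 1.0; naive exp() would overflow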
This exemplifies modular GPU architecture: complex neural network operations built by orchestrating proven, optimized kernels rather than monolithic implementations!
Key implementation insights
Memory optimization strategy
The implementation achieves minimal memory allocation through aggressive buffer reuse:
# Only 2 temporary buffers needed for the entire operation
k_t_buf = gpu_ctx.enqueue_create_buffer[dtype](seq_len * d)
scores_weights_buf = gpu_ctx.enqueue_create_buffer[dtype](seq_len)
Key optimization insights:
- The same scores_weights_buf is reused for both attention scores and weights through reshape operations
- Zero-copy tensor reshaping eliminates unnecessary data movement
Kernel reuse architecture
This puzzle showcases modular kernel design by combining three specialized kernels:
- matmul_idiomatic_tiled (used twice) - Powers both \(Q \times K^T\) and \(\text{weights} \times V\) operations
- softmax_kernel - Provides numerically stable attention weight computation with parallel reduction
- transpose_kernel - Enables efficient \(K^T\) computation with coalesced memory access
Architectural benefits:
- Composability: Complex operations built from proven components
- Maintainability: Each kernel has a single, well-defined responsibility
- Performance: Leverages highly optimized implementations from previous puzzles
- Scalability: Modular design enables easy extension to larger attention mechanisms
The implementation demonstrates that sophisticated neural network operations can be built by orchestrating simpler, well-tested GPU kernels rather than writing monolithic implementations.