Tiled Matrix Multiplication
Overview
Implement a kernel that multiplies square matrices \(A\) and \(B\) using tiled matrix multiplication with LayoutTensor. This approach handles large matrices by processing them in smaller chunks (tiles).
Key concepts
- Matrix tiling with LayoutTensor for efficient computation
- Multi-block coordination with proper layouts
- Efficient shared memory usage through TensorBuilder
- Boundary handling for tiles with LayoutTensor indexing
Configuration
- Matrix size: \(\text{SIZE_TILED} = 9\)
- Threads per block: \(\text{TPB} \times \text{TPB} = 3 \times 3\)
- Grid dimensions: \(3 \times 3\) blocks
- Shared memory: Two \(\text{TPB} \times \text{TPB}\) LayoutTensors per block
Layout configuration:
- Input A: Layout.row_major(SIZE_TILED, SIZE_TILED)
- Input B: Layout.row_major(SIZE_TILED, SIZE_TILED)
- Output: Layout.row_major(SIZE_TILED, SIZE_TILED)
- Shared memory: Two TPB × TPB LayoutTensors created with TensorBuilder
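For reference, the two shared-memory tiles are created with the TensorBuilder (tb) pattern; this is the same call that appears in the solutions below, shown here in isolation:

    # Per-block shared-memory tiles (same calls as in the solutions below)
    a_shared = tb[dtype]().row_major[TPB, TPB]().shared().alloc()
    b_shared = tb[dtype]().row_major[TPB, TPB]().shared().alloc()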
Tiling strategy
Block organization
Grid Layout (3×3): Thread Layout per Block (3×3):
[B00][B01][B02] [T00 T01 T02]
[B10][B11][B12] [T10 T11 T12]
[B20][B21][B22] [T20 T21 T22]
Each block processes a tile using LayoutTensor indexing
Tile processing steps
- Calculate global and local indices for thread position
- Allocate shared memory for A and B tiles
- For each tile:
- Load tile from matrix A and B
- Compute partial products
- Accumulate results in registers
- Write final accumulated result
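As a rough sketch of how these steps map onto the kernel body (comments only, assuming the TPB-wide square tiles described above; the exact code is in the solutions below):

    # Per-thread skeleton of the tiled kernel (sketch, not the full solution):
    # 1. local_row/local_col = position inside the block (thread_idx.y / thread_idx.x)
    #    global row/col      = block offset * TPB + local position
    # 2. allocate a_shared and b_shared (TPB x TPB each) in shared memory
    # 3. acc = 0                       # per-thread register accumulator
    #    for each of the SIZE_TILED // TPB tile phases:
    #        load one element of the A tile and one of the B tile into shared memory
    #        barrier()
    #        acc += dot product of my row of a_shared with my column of b_shared
    #        barrier()
    # 4. write acc to output[global row, global col]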
Memory access pattern
Matrix A (9×9)        Matrix B (9×9)        Matrix C (9×9)
+---+---+---+ +---+---+---+ +---+---+---+
|T00|T01|T02| ... |T00|T01|T02| ... |T00|T01|T02| ...
+---+---+---+ +---+---+---+ +---+---+---+
|T10|T11|T12| |T10|T11|T12| |T10|T11|T12|
+---+---+---+ +---+---+---+ +---+---+---+
|T20|T21|T22| |T20|T21|T22| |T20|T21|T22|
+---+---+---+ +---+---+---+ +---+---+---+
... ... ...
Tile Processing (for computing C[T11]):
1. Load tiles from A and B:
   +---+     +---+
   |A11|  ×  |B11|      For each phase k:
   +---+     +---+        C[T11] += A[T1k] × B[Tk1]
2. Tile movement:
Phase 1 Phase 2 Phase 3
A: [T10] A: [T11] A: [T12]
B: [T01] B: [T11] B: [T21]
3. Each thread (i,j) in tile computes:
C[i,j] = Σ (A[i,k] × B[k,j]) for k in tile width
Synchronization required:
* After loading tiles to shared memory
* After computing each phase
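The per-thread formula above is just the full dot product split into TPB-wide chunks. Concretely, for \(\text{TPB} = 3\) and the \(9 \times 9\) matrices used here:

\[
C[i,j] = \sum_{k=0}^{8} A[i,k]\, B[k,j] = \sum_{t=0}^{2} \left( \sum_{k=0}^{2} A[i,\, 3t+k]\, B[3t+k,\, j] \right)
\]

Each inner sum is the partial product that one tile phase \(t\) accumulates from the shared-memory tiles.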
Code to complete
alias SIZE_TILED = 9
alias BLOCKS_PER_GRID_TILED = (3, 3)  # each block covers a 3x3 tile of elements
alias THREADS_PER_BLOCK_TILED = (TPB, TPB)
alias layout_tiled = Layout.row_major(SIZE_TILED, SIZE_TILED)
fn matmul_tiled[
    layout: Layout, size: Int
](
    output: LayoutTensor[mut=False, dtype, layout],
    a: LayoutTensor[mut=False, dtype, layout],
    b: LayoutTensor[mut=False, dtype, layout],
):
    local_row = thread_idx.y
    local_col = thread_idx.x
    tiled_row = block_idx.y * TPB + thread_idx.y
    tiled_col = block_idx.x * TPB + thread_idx.x
    # FILL ME IN (roughly 20 lines)
View full file: problems/p14/p14.mojo
Tips
- Use the standard indexing convention: local_row = thread_idx.y and local_col = thread_idx.x
- Calculate global positions: global_row = block_idx.y * TPB + local_row and global_col = block_idx.x * TPB + local_col
- Understanding the global indexing formula:
  - Each block processes a TPB × TPB tile of the matrix
  - block_idx.y tells us which row of blocks we're in (0, 1, 2, ...)
  - block_idx.y * TPB gives us the starting row of our block's tile
  - local_row (0 to TPB-1) is our thread's offset within the block
  - Adding them gives our thread's actual row in the full matrix

  Example with TPB=3:

  Block Layout:       Global Matrix (9×9):
  [B00][B01][B02]     [0 1 2 | 3 4 5 | 6 7 8]
  [B10][B11][B12]  →  [9 A B | C D E | F G H]
  [B20][B21][B22]     [I J K | L M N | O P Q]
                      ------------------------
                      [R S T | U V W | X Y Z]
                      [a b c | d e f | g h i]
                      [j k l | m n o | p q r]
                      ------------------------
                      [s t u | v w x | y z α]
                      [β γ δ | ε ζ η | θ ι κ]
                      [λ μ ν | ξ ο π | ρ σ τ]

  Thread(1,2) in Block(1,0):
  - block_idx.y = 1, local_row = 1
  - global_row = 1 * 3 + 1 = 4
  - This thread handles row 4 of the matrix
- Allocate shared memory (now pre-initialized with .fill(0))
- With 9×9 perfect tiling, no bounds checking is needed!
- Accumulate results across tiles with proper synchronization
Running the code
To test your solution, run the following command in your terminal:
uv run poe p14 --tiled
or, if you use pixi:
pixi run p14 --tiled
Your output will look like this if the puzzle isn’t solved yet:
out: HostBuffer([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
expected: HostBuffer([3672.0, 3744.0, 3816.0, 3888.0, 3960.0, 4032.0, 4104.0, 4176.0, 4248.0, 9504.0, 9738.0, 9972.0, 10206.0, 10440.0, 10674.0, 10908.0, 11142.0, 11376.0, 15336.0, 15732.0, 16128.0, 16524.0, 16920.0, 17316.0, 17712.0, 18108.0, 18504.0, 21168.0, 21726.0, 22284.0, 22842.0, 23400.0, 23958.0, 24516.0, 25074.0, 25632.0, 27000.0, 27720.0, 28440.0, 29160.0, 29880.0, 30600.0, 31320.0, 32040.0, 32760.0, 32832.0, 33714.0, 34596.0, 35478.0, 36360.0, 37242.0, 38124.0, 39006.0, 39888.0, 38664.0, 39708.0, 40752.0, 41796.0, 42840.0, 43884.0, 44928.0, 45972.0, 47016.0, 44496.0, 45702.0, 46908.0, 48114.0, 49320.0, 50526.0, 51732.0, 52938.0, 54144.0, 50328.0, 51696.0, 53064.0, 54432.0, 55800.0, 57168.0, 58536.0, 59904.0, 61272.0])
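As a quick sanity check on the expected values (using the input matrices shown in the solution walkthrough below, where \(A\) holds \(0, 1, 2, \ldots, 80\) in row-major order and \(B\) holds the doubled sequence \(0, 2, 4, \ldots, 160\)):

\[
C[0,0] = \sum_{k=0}^{8} A[0,k]\, B[k,0] = \sum_{k=0}^{8} k \cdot 18k = 18 \sum_{k=0}^{8} k^2 = 18 \cdot 204 = 3672
\]

which matches the first entry of the expected buffer.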
Solution: Manual tiling
fn matmul_tiled[
    layout: Layout, size: Int
](
    output: LayoutTensor[mut=False, dtype, layout],
    a: LayoutTensor[mut=False, dtype, layout],
    b: LayoutTensor[mut=False, dtype, layout],
):
    local_row = thread_idx.y
    local_col = thread_idx.x
    tiled_row = block_idx.y * TPB + local_row
    tiled_col = block_idx.x * TPB + local_col

    a_shared = tb[dtype]().row_major[TPB, TPB]().shared().alloc()
    b_shared = tb[dtype]().row_major[TPB, TPB]().shared().alloc()

    var acc: output.element_type = 0

    # Iterate over tiles to compute matrix product
    @parameter
    for tile in range((size + TPB - 1) // TPB):
        # Load A tile - global row stays the same, col determined by tile
        if tiled_row < size and (tile * TPB + local_col) < size:
            a_shared[local_row, local_col] = a[
                tiled_row, tile * TPB + local_col
            ]
        # Load B tile - row determined by tile, global col stays the same
        if (tile * TPB + local_row) < size and tiled_col < size:
            b_shared[local_row, local_col] = b[
                tile * TPB + local_row, tiled_col
            ]

        barrier()

        # Matrix multiplication within the tile
        if tiled_row < size and tiled_col < size:

            @parameter
            for k in range(min(TPB, size - tile * TPB)):
                acc += a_shared[local_row, k] * b_shared[k, local_col]

        barrier()

    # Write out final result
    if tiled_row < size and tiled_col < size:
        output[tiled_row, tiled_col] = acc
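To make the cooperative loading concrete, here is a hypothetical trace of one thread's loads across the three tile phases (the thread and block indices are chosen arbitrarily for illustration):

    # Thread (local_row=1, local_col=2) in block (block_idx.y=1, block_idx.x=0),
    # i.e. tiled_row = 4, tiled_col = 2, executes the loads above as:
    #   tile 0: a_shared[1,2] = a[4, 0*3 + 2] = a[4, 2]    b_shared[1,2] = b[0*3 + 1, 2] = b[1, 2]
    #   tile 1: a_shared[1,2] = a[4, 1*3 + 2] = a[4, 5]    b_shared[1,2] = b[1*3 + 1, 2] = b[4, 2]
    #   tile 2: a_shared[1,2] = a[4, 2*3 + 2] = a[4, 8]    b_shared[1,2] = b[2*3 + 1, 2] = b[7, 2]
    # Together with the other 8 threads in the block, each phase stages one full
    # 3×3 tile of A and one of B, which the inner k-loop then consumes.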
The tiled matrix multiplication implementation demonstrates efficient handling of matrices \((9 \times 9)\) using small tiles \((3 \times 3)\). Here’s how it works:
- Shared memory allocation

  Input matrices (9×9) - perfect fit for (3×3) tiling:

  A = [ 0  1  2  3  4  5  6  7  8]    B = [  0   2   4   6   8  10  12  14  16]
      [ 9 10 11 12 13 14 15 16 17]        [ 18  20  22  24  26  28  30  32  34]
      [18 19 20 21 22 23 24 25 26]        [ 36  38  40  42  44  46  48  50  52]
      [27 28 29 30 31 32 33 34 35]        [ 54  56  58  60  62  64  66  68  70]
      [36 37 38 39 40 41 42 43 44]        [ 72  74  76  78  80  82  84  86  88]
      [45 46 47 48 49 50 51 52 53]        [ 90  92  94  96  98 100 102 104 106]
      [54 55 56 57 58 59 60 61 62]        [108 110 112 114 116 118 120 122 124]
      [63 64 65 66 67 68 69 70 71]        [126 128 130 132 134 136 138 140 142]
      [72 73 74 75 76 77 78 79 80]        [144 146 148 150 152 154 156 158 160]

  Shared memory per block (3×3):
  a_shared[TPB, TPB]
  b_shared[TPB, TPB]
- Tile processing loop

  Number of tiles = 9 // 3 = 3 tiles (perfect division!)

  For each tile:
  1. Load tile from A and B
  2. Compute partial products
  3. Accumulate in register
- Memory loading pattern

  # Load A tile - global row stays the same, col determined by tile
  if tiled_row < size and (tile * TPB + local_col) < size:
      a_shared[local_row, local_col] = a[
          tiled_row, tile * TPB + local_col
      ]
  # Load B tile - row determined by tile, global col stays the same
  if (tile * TPB + local_row) < size and tiled_col < size:
      b_shared[local_row, local_col] = b[
          tile * TPB + local_row, tiled_col
      ]

  With perfect \((9 \times 9)\) tiling these bounds checks are technically unnecessary, but they are included for defensive programming and consistency with other matrix sizes.
- Computation within tile

  for k in range(min(TPB, size - tile * TPB)):
      acc += a_shared[local_row, k] * b_shared[k, local_col]
  This access pattern avoids shared memory bank conflicts:

  Bank conflict free (good):                 Bank conflicts (bad):
  Thread0: a_shared[0,k]  b_shared[k,0]      Thread0: a_shared[k,0]  b_shared[0,k]
  Thread1: a_shared[0,k]  b_shared[k,1]      Thread1: a_shared[k,0]  b_shared[1,k]
  Thread2: a_shared[0,k]  b_shared[k,2]      Thread2: a_shared[k,0]  b_shared[2,k]
            ↓                                          ↓
  Parallel access to different banks         Serialized access to the same bank of b_shared
  (or broadcast for a_shared)                if shared memory were col_major (transposed)

  Left: threads from the same warp access different banks of shared memory during the same memory transaction (b_shared[k, threadIdx.x]), or all threads in the warp access the same shared memory address (a_shared[0, k]), so the accesses can proceed in parallel.

  Right: threads from the same warp access the same bank simultaneously (bank k in b_shared), so the accesses are serialized, reducing performance.

  Note: shared memory has 32 banks on modern GPUs. Bank conflicts occur when multiple threads in a warp access different addresses in the same bank simultaneously; accesses to the same address are broadcast and handled efficiently.
- Synchronization points

  barrier() after:
  1. Tile loading
  2. Tile computation
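The reason both barriers are needed (a brief note; the hazard names are standard GPU terminology, not from the puzzle itself): the first barrier prevents a read-after-write hazard, the second a write-after-read hazard across the block:

    # Inside the tile loop of the solution above:
    #   barrier() after loading   -> no thread may read a_shared/b_shared until
    #                                every thread has finished writing its element
    #                                (read-after-write hazard)
    #   barrier() after computing -> no thread may start loading the *next* tile
    #                                into a_shared/b_shared until every thread has
    #                                finished reading the current one
    #                                (write-after-read hazard)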
Key performance features:
- Processes \((9 \times 9)\) matrix using \((3 \times 3)\) tiles (perfect fit!)
- Uses shared memory for fast tile access
- Minimizes global memory transactions with coalesced memory access
- Optimized shared memory layout and access pattern to avoid shared memory bank conflicts
- Result writing:

  if tiled_row < size and tiled_col < size:
      output[tiled_row, tiled_col] = acc
- Defensive bounds checking included for other matrix sizes and tiling strategies
- Direct assignment to output matrix
- All threads write valid results
Key optimizations
- Layout optimization:
  - Row-major layout for all tensors
  - Efficient 2D indexing
- Memory access:
  - Coalesced global memory loads
  - Efficient shared memory usage
- Computation:
  - Register-based accumulation, i.e. var acc: output.element_type = 0
  - Compile-time loop unrolling via @parameter
This implementation achieves high performance through:
- Efficient use of LayoutTensor for memory access
- Optimal tiling strategy
- Proper thread synchronization
- Careful boundary handling
Solution: Idiomatic LayoutTensor tiling
from gpu.memory import async_copy_wait_all
from layout.layout_tensor import copy_dram_to_sram_async


fn matmul_idiomatic_tiled[
    layout: Layout, size: Int
](
    output: LayoutTensor[mut=False, dtype, layout],
    a: LayoutTensor[mut=False, dtype, layout],
    b: LayoutTensor[mut=False, dtype, layout],
):
    local_row = thread_idx.y
    local_col = thread_idx.x
    tiled_row = block_idx.y * TPB + local_row
    tiled_col = block_idx.x * TPB + local_col

    # Get the tile of the output matrix that this thread block is responsible for
    out_tile = output.tile[TPB, TPB](block_idx.y, block_idx.x)

    a_shared = tb[dtype]().row_major[TPB, TPB]().shared().alloc().fill(0)
    b_shared = tb[dtype]().row_major[TPB, TPB]().shared().alloc().fill(0)

    var acc: output.element_type = 0

    alias load_a_layout = Layout.row_major(1, TPB)
    alias load_b_layout = Layout.row_major(TPB, 1)

    @parameter
    for idx in range(size // TPB):  # Perfect division: 9 // 3 = 3 tiles
        # Get tiles from A and B matrices
        a_tile = a.tile[TPB, TPB](block_idx.y, idx)
        b_tile = b.tile[TPB, TPB](idx, block_idx.x)

        # Asynchronously copy tiles to shared memory
        copy_dram_to_sram_async[thread_layout=load_a_layout](a_shared, a_tile)
        copy_dram_to_sram_async[thread_layout=load_b_layout](b_shared, b_tile)

        # Wait for all async copies to complete
        async_copy_wait_all()
        barrier()

        # Compute partial matrix multiplication for this tile
        @parameter
        for k in range(TPB):
            acc += a_shared[local_row, k] * b_shared[k, local_col]

        barrier()

    # Write final result to output tile
    if tiled_row < size and tiled_col < size:
        out_tile[local_row, local_col] = acc
The idiomatic tiled matrix multiplication leverages Mojo’s LayoutTensor API and asynchronous memory operations for a beautifully clean implementation. With the \((9 \times 9)\) matrix size, we get perfect tiling that eliminates all boundary checks:
- LayoutTensor tile API

  out_tile = output.tile[TPB, TPB](block_idx.y, block_idx.x)
  a_tile = a.tile[TPB, TPB](block_idx.y, idx)
  b_tile = b.tile[TPB, TPB](idx, block_idx.x)

  This directly expresses "get the tile at position (block_idx.y, block_idx.x)" without manual coordinate calculation; a sketch of what the tile view refers to follows this list. See the documentation for more details.
- Asynchronous memory operations

  copy_dram_to_sram_async[thread_layout=load_a_layout](a_shared, a_tile)
  copy_dram_to_sram_async[thread_layout=load_b_layout](b_shared, b_tile)
  async_copy_wait_all()
These operations:
- Use dedicated copy engines that bypass registers and enable compute-memory overlap via copy_dram_to_sram_async
- Use specialized thread layouts for optimal memory access patterns
- Eliminate the need for manual memory initialization
- Note: Standard GPU loads are already asynchronous; these provide better resource utilization
- Specialized compile-time load layouts

  alias load_a_layout = Layout.row_major(1, TPB)
  alias load_b_layout = Layout.row_major(TPB, 1)

  These layouts optimize how threads cooperate during memory transfers:
  - load_a_layout: each thread loads a slice of a row (coalesced access)
  - load_b_layout: each thread loads a slice of a column (transposed access)
- Perfect tiling eliminates boundary checks

  @parameter
  for idx in range(size // TPB):  # Perfect division: 9 // 3 = 3
With \((9 \times 9)\) matrices and \((3 \times 3)\) tiles, every tile is exactly full-sized. No boundary checking needed!
- Clean tile processing with defensive bounds checking

  # Defensive bounds checking included even with perfect tiling
  if tiled_row < size and tiled_col < size:
      out_tile[local_row, local_col] = acc
With perfect \((9 \times 9)\) tiling, this bounds check is technically unnecessary but included for defensive programming and consistency with other matrix sizes.
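As promised above, a hedged sketch of what the tile view in the idiomatic solution refers to (illustrative indexing only, not new API):

    # For 0 <= r, c < TPB, the tile view created by
    #     out_tile = output.tile[TPB, TPB](block_idx.y, block_idx.x)
    # addresses the same element that the manual solution writes directly:
    #     out_tile[r, c]  <->  output[block_idx.y * TPB + r, block_idx.x * TPB + c]
    # which is why both kernels produce identical results.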
Performance considerations
The idiomatic implementation maintains the performance benefits of tiling while providing cleaner abstractions:
- Memory locality: Exploits spatial and temporal locality through tiling
- Coalesced access: Specialized load layouts ensure coalesced memory access patterns
- Compute-memory overlap: Potential overlap through asynchronous memory operations
- Shared memory efficiency: No redundant initialization of shared memory
- Register pressure: Uses accumulation registers for optimal compute throughput
This implementation shows how high-level abstractions can express complex GPU algorithms without sacrificing performance. It’s a prime example of Mojo’s philosophy: combining high-level expressiveness with low-level performance control.
Key differences from manual tiling
| Feature         | Manual Tiling                                | Idiomatic Tiling                       |
|-----------------|----------------------------------------------|----------------------------------------|
| Memory access   | Direct indexing with bounds checks           | LayoutTensor tile API                  |
| Tile loading    | Explicit element-by-element copying          | Dedicated copy engine bulk transfers   |
| Shared memory   | Manual initialization (zeroing)              | Managed by copy functions              |
| Code complexity | More verbose with explicit indexing          | More concise with higher-level APIs    |
| Bounds checking | Multiple checks during loading and computing | Single defensive check at final write  |
The idiomatic approach is not just cleaner but also potentially more performant due to the use of specialized memory layouts and asynchronous operations.