Tiled Matrix Multiplication
Overview
Implement a kernel that multiplies square matrices \(A\) and \(B\) using tiled matrix multiplication with LayoutTensor. This approach handles large matrices by processing them in smaller chunks (tiles).
Key concepts
- Matrix tiling with LayoutTensor for efficient computation
- Multi-block coordination with proper layouts
- Efficient shared memory usage through TensorBuilder
- Boundary handling for tiles with LayoutTensor indexing
Configuration
- Matrix size: \(\text{SIZE_TILED} = 8\)
- Threads per block: \(\text{TPB} \times \text{TPB} = 3 \times 3\)
- Grid dimensions: \(3 \times 3\) blocks
- Shared memory: Two \(\text{TPB} \times \text{TPB}\) LayoutTensors per block
Layout configuration:
- Input A: Layout.row_major(SIZE_TILED, SIZE_TILED)
- Input B: Layout.row_major(SIZE_TILED, SIZE_TILED)
- Output: Layout.row_major(SIZE_TILED, SIZED_TILED = SIZE_TILED) → Layout.row_major(SIZE_TILED, SIZE_TILED)
- Shared memory: Two TPB × TPB LayoutTensors allocated through TensorBuilder
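For the shared tiles, the TensorBuilder chain used later in the solution looks like this (assuming `tb`, `dtype`, and `TPB` are in scope from the puzzle's common imports):

    # Two TPB × TPB tiles in shared memory, one for the current A sub-block and one for B
    a_shared = tb[dtype]().row_major[TPB, TPB]().shared().alloc()
    b_shared = tb[dtype]().row_major[TPB, TPB]().shared().alloc()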
Tiling strategy
Block organization
Grid Layout (3×3):          Thread Layout per Block (3×3):
[B00][B01][B02]             [T00 T01 T02]
[B10][B11][B12]             [T10 T11 T12]
[B20][B21][B22]             [T20 T21 T22]
Each block computes one TPB × TPB tile of the output using LayoutTensor indexing. For example, thread (1, 2) of block (2, 0) computes output element (2 · 3 + 1, 0 · 3 + 2) = (7, 2).
Tile processing steps
- Calculate global and local indices for thread position
- Allocate shared memory for A and B tiles
- For each tile:
- Reset shared memory
- Load tile from matrix A and B
- Compute partial products
- Accumulate results in registers
- Write final accumulated result
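Ignoring boundary checks, and assuming the matrix size were an exact multiple of TPB, these steps reduce to the following sketch of the kernel body (it reuses the index variables from the starter code below and the two shared tiles shown above; the full solution adds the bounds checks and the shared-memory reset):

var acc: out.element_type = 0

@parameter
for tile in range(size // TPB):
    # Each thread loads one element of the current A and B tiles
    a_shared[local_row, local_col] = a[global_row, tile * TPB + local_col]
    b_shared[local_row, local_col] = b[tile * TPB + local_row, global_col]
    barrier()

    # Multiply-accumulate across the shared tiles
    @parameter
    for k in range(TPB):
        acc += a_shared[local_row, k] * b_shared[k, local_col]
    barrier()

# After the last tile, write the accumulated dot product
out[global_row, global_col] = acc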
Memory access pattern
For each tile:
Input tiles:                        Output computation:
A[global_i, tile*TPB + j]     ×     result accumulator
B[tile*TPB + i, global_j]     →     C[global_i, global_j]
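Written out, this is just the ordinary matrix product with the inner sum split into per-tile chunks, where \(N = 8\) is the matrix size (SIZE_TILED) and the inner bound matches the min(...) used in the solution's inner loop:

\[
C_{ij} \;=\; \sum_{t=0}^{\lceil N/\mathrm{TPB}\rceil - 1} \;\; \sum_{k=0}^{\min(\mathrm{TPB},\, N - t\cdot\mathrm{TPB}) - 1} A_{i,\; t\cdot\mathrm{TPB} + k} \cdot B_{t\cdot\mathrm{TPB} + k,\; j}
\]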
Code to complete
alias SIZE_TILED = 8
alias BLOCKS_PER_GRID_TILED = (3, 3)  # each block covers a 3x3 tile of the output
alias THREADS_PER_BLOCK_TILED = (TPB, TPB)
alias layout_tiled = Layout.row_major(SIZE_TILED, SIZE_TILED)
fn matmul_tiled[
    layout: Layout, size: Int
](
    out: LayoutTensor[mut=False, dtype, layout],
    a: LayoutTensor[mut=False, dtype, layout],
    b: LayoutTensor[mut=False, dtype, layout],
):
    local_row = thread_idx.x
    local_col = thread_idx.y
    global_row = block_idx.x * TPB + local_row
    global_col = block_idx.y * TPB + local_col
    # FILL ME IN (roughly 20 lines)
View full file: problems/p14/p14.mojo
Tips
- Calculate global thread positions from block and thread indices correctly
- Clear shared memory before loading new tiles
- Load tiles with proper bounds checking
- Accumulate results across tiles with proper synchronization
Running the code
To test your solution, run the following command in your terminal:
magic run p14 --tiled
Your output will look like this if the puzzle isn’t solved yet:
out: HostBuffer([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
expected: HostBuffer([2240.0, 2296.0, 2352.0, 2408.0, 2464.0, 2520.0, 2576.0, 2632.0, 5824.0, 6008.0, 6192.0, 6376.0, 6560.0, 6744.0, 6928.0, 7112.0, 9408.0, 9720.0, 10032.0, 10344.0, 10656.0, 10968.0, 11280.0, 11592.0, 12992.0, 13432.0, 13872.0, 14312.0, 14752.0, 15192.0, 15632.0, 16072.0, 16576.0, 17144.0, 17712.0, 18280.0, 18848.0, 19416.0, 19984.0, 20552.0, 20160.0, 20856.0, 21552.0, 22248.0, 22944.0, 23640.0, 24336.0, 25032.0, 23744.0, 24568.0, 25392.0, 26216.0, 27040.0, 27864.0, 28688.0, 29512.0, 27328.0, 28280.0, 29232.0, 30184.0, 31136.0, 32088.0, 33040.0, 33992.0])
Solution
fn matmul_tiled[
    layout: Layout, size: Int
](
    out: LayoutTensor[mut=False, dtype, layout],
    a: LayoutTensor[mut=False, dtype, layout],
    b: LayoutTensor[mut=False, dtype, layout],
):
    local_i = thread_idx.x
    local_j = thread_idx.y
    global_i = block_idx.x * TPB + local_i
    global_j = block_idx.y * TPB + local_j

    a_shared = tb[dtype]().row_major[TPB, TPB]().shared().alloc()
    b_shared = tb[dtype]().row_major[TPB, TPB]().shared().alloc()

    var acc: out.element_type = 0

    # Iterate over tiles to compute matrix product
    @parameter
    for tile in range((size + TPB - 1) // TPB):
        # Reset shared memory tiles
        if local_i < TPB and local_j < TPB:
            a_shared[local_i, local_j] = 0
            b_shared[local_i, local_j] = 0

        barrier()

        # Load A tile - global row stays the same, col determined by tile
        if global_i < size and (tile * TPB + local_j) < size:
            a_shared[local_i, local_j] = a[global_i, tile * TPB + local_j]

        # Load B tile - row determined by tile, global col stays the same
        if (tile * TPB + local_i) < size and global_j < size:
            b_shared[local_i, local_j] = b[tile * TPB + local_i, global_j]

        barrier()

        # Matrix multiplication within the tile
        if global_i < size and global_j < size:

            @parameter
            for k in range(min(TPB, size - tile * TPB)):
                acc += a_shared[local_i, k] * b_shared[k, local_j]

        barrier()

    # Write out final result
    if global_i < size and global_j < size:
        out[global_i, global_j] = acc
The tiled implementation with LayoutTensor handles matrices efficiently by processing them in blocks. Here’s a comprehensive analysis:
Implementation Architecture
- Thread and Block Organization:

  local_i = thread_idx.x
  local_j = thread_idx.y
  global_i = block_idx.x * TPB + local_i
  global_j = block_idx.y * TPB + local_j

  - Each thread owns one element of the output tile
  - Global indices combine the block offset with the thread's local position
  - Local indices address the shared-memory tiles
- Shared Memory Setup:

  a_shared = tb[dtype]().row_major[TPB, TPB]().shared().alloc()
  b_shared = tb[dtype]().row_major[TPB, TPB]().shared().alloc()

  - Uses TensorBuilder for structured allocation
  - Maintains row-major layout for consistency
  - Enables efficient tile processing
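  As a quick size check (assuming the puzzle's dtype is float32): the two tiles occupy only 2 · 3 · 3 · 4 = 72 bytes of shared memory per block, far below any hardware limit.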
Tile Processing Pipeline
- Tile Iteration:

  @parameter
  for tile in range((size + TPB - 1) // TPB):

  - Compile-time unrolled loop
  - Handles matrix sizes not divisible by TPB
  - Processes the matrix in TPB×TPB tiles
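  With size = 8 and TPB = 3 this evaluates to (8 + 3 - 1) // 3 = 3 iterations: tiles 0 and 1 are full 3×3 tiles, while tile 2 covers only the last two rows and columns of A and B.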
- Shared Memory Reset:

  if local_i < TPB and local_j < TPB:
      a_shared[local_i, local_j] = 0
      b_shared[local_i, local_j] = 0

  - Clears previous tile data
  - Ensures a clean state for the new tile
  - Prevents data corruption
- Tile Loading:

  # Load A tile - global row stays the same, col determined by tile
  if global_i < size and (tile * TPB + local_j) < size:
      a_shared[local_i, local_j] = a[global_i, tile * TPB + local_j]

  # Load B tile - row determined by tile, global col stays the same
  if (tile * TPB + local_i) < size and global_j < size:
      b_shared[local_i, local_j] = b[tile * TPB + local_i, global_j]

  - Handles boundary conditions
  - Uses LayoutTensor indexing
  - Maintains memory coalescing
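  For example, on tile 2, thread (0, 0) of block (1, 1) has global_i = global_j = 3, so it loads a_shared[0, 0] = a[3, 2·3 + 0] = a[3, 6] and b_shared[0, 0] = b[2·3 + 0, 3] = b[6, 3], both within the 8×8 bounds.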
- Computation:

  @parameter
  for k in range(min(TPB, size - tile * TPB)):
      acc += a_shared[local_i, k] * b_shared[k, local_j]

  - Processes the current tile
  - Uses shared memory for efficiency
  - Handles partial tiles correctly
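  On the last tile (tile = 2) the bound is min(3, 8 − 2·3) = 2, so only the two valid entries along k are accumulated and the zeroed padding in shared memory is never read.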
Memory Access Optimization
- Global Memory Pattern:

  A[global_i, tile * TPB + local_j]  → row-major access
  B[tile * TPB + local_i, global_j]  → row-major access

  - Maximizes memory coalescing: when threads access consecutive memory locations (left), the GPU can combine multiple reads into a single transaction; when threads access scattered locations (right), each access requires a separate transaction, reducing performance.

    Coalesced Access (Good):          Non-Coalesced Access (Bad):
    Thread0: [M0][M1][M2][M3]         Thread0: [M0][  ][  ][  ]
    Thread1: [M4][M5][M6][M7]    vs   Thread1: [  ][M1][  ][  ]
    Thread2: [M8][M9][MA][MB]         Thread2: [  ][  ][M2][  ]
    Thread3: [MC][MD][ME][MF]         Thread3: [  ][  ][  ][M3]
             ↓                                 ↓
      1 memory transaction              4 memory transactions
- Shared Memory Usage:

  a_shared[local_i, k]  → row-wise access (each thread reads across a row of the A tile)
  b_shared[k, local_j]  → column-wise access (each thread reads down a column of the B tile)

  - Optimized for matrix multiplication
  - Reduces bank conflicts
  - Enables data reuse
Synchronization and Safety
- Barrier Points:

  barrier()  # After shared memory reset
  barrier()  # After tile loading
  barrier()  # After computation

  - Ensures shared memory consistency
  - Prevents race conditions: no thread reads a tile before every thread has finished loading it
  - The barrier after computation keeps a fast thread from clearing or reloading the shared tiles for the next iteration while slower threads are still reading the current ones
- Boundary Handling:

  if global_i < size and global_j < size:
      out[global_i, global_j] = acc

  - Prevents out-of-bounds access
  - Handles matrix edges
  - Safe result writing
Key Optimizations
- Layout Optimization:
  - Row-major layout for all tensors
  - Efficient 2D indexing
  - Proper alignment
- Memory Access:
  - Coalesced global memory loads
  - Efficient shared memory usage
  - Minimal bank conflicts
- Computation:
  - Register-based accumulation
  - Compile-time loop unrolling
  - Efficient tile processing
This implementation achieves high performance through:
- Efficient use of LayoutTensor for memory access
- Optimal tiling strategy
- Proper thread synchronization
- Careful boundary handling
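For reference, the host side in problems/p14/p14.mojo launches this kernel over the 3×3 grid. A minimal sketch of that launch, assuming the DeviceContext pattern used throughout these puzzles and that a_tensor, b_tensor, and out_tensor already wrap the device buffers:

from gpu.host import DeviceContext

def main():
    with DeviceContext() as ctx:
        # ... create device buffers, fill a and b, wrap them in LayoutTensors ...
        ctx.enqueue_function[matmul_tiled[layout_tiled, SIZE_TILED]](
            out_tensor,
            a_tensor,
            b_tensor,
            grid_dim=BLOCKS_PER_GRID_TILED,
            block_dim=THREADS_PER_BLOCK_TILED,
        )
        ctx.synchronize()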