## Overview

Implement a kernel that adds 10 to each position of a vector `a` and stores it in `output`.

**Note:** _You have fewer threads per block than the size of `a`._
## Key concepts

In this puzzle, you'll learn about:
- Using shared memory within thread blocks
- Synchronizing threads with barriers
- Managing block-local data storage
The key insight is that shared memory provides fast, block-local storage that all threads in a block can access, but using it correctly requires careful coordination between those threads.
## Configuration

- Array size: `SIZE = 8` elements
- Threads per block: `TPB = 4`
- Number of blocks: 2
- Shared memory: `TPB` elements per block
Notes:
- Shared memory: Fast storage shared by threads in a block
- Thread sync: Coordination using `barrier()`
- Memory scope: Shared memory only visible within block
- Access pattern: Local vs global indexing
Warning: Each block has only a fixed amount of shared memory that its threads can read and write. The size of a shared allocation must be a compile-time constant (here, the `alias TPB` below), not a runtime variable. After writing to shared memory, call `barrier()` so that no thread reads a slot before all threads have finished writing.
Educational Note: In this specific puzzle, the `barrier()` isn't strictly necessary since each thread only accesses its own shared memory location. However, it's included to teach proper shared memory synchronization patterns for more complex scenarios where threads need to coordinate access to shared data.
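To see why the barrier matters in general, consider a hypothetical variant (a sketch of my own, not part of this puzzle) in which each thread reads its *neighbor's* shared slot. Here the `barrier()` is load-bearing: thread `i` must not read `shared[local_i + 1]` until thread `i + 1` has written it. Import paths follow the Mojo `gpu`/`memory` modules as used elsewhere in these puzzles.

```mojo
from gpu import barrier, block_dim, block_idx, thread_idx
from gpu.memory import AddressSpace
from memory import UnsafePointer, stack_allocation

alias TPB = 4
alias dtype = DType.float32


# Hypothetical kernel: each thread reads its right neighbor's shared slot,
# so every thread must wait until all loads are complete.
fn shift_left(
    output: UnsafePointer[Scalar[dtype]],
    a: UnsafePointer[Scalar[dtype]],
    size: Int,
):
    shared = stack_allocation[
        TPB,
        Scalar[dtype],
        address_space = AddressSpace.SHARED,
    ]()
    global_i = block_dim.x * block_idx.x + thread_idx.x
    local_i = thread_idx.x

    if global_i < size:
        shared[local_i] = a[global_i]

    # Without this barrier, thread i could read shared[local_i + 1]
    # before thread i + 1 has written it.
    barrier()

    if global_i + 1 < size and local_i + 1 < TPB:
        output[global_i] = shared[local_i + 1]
```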
## Code to complete
```mojo
alias TPB = 4
alias SIZE = 8
alias BLOCKS_PER_GRID = (2, 1)
alias THREADS_PER_BLOCK = (TPB, 1)
alias dtype = DType.float32


fn add_10_shared(
    output: UnsafePointer[Scalar[dtype]],
    a: UnsafePointer[Scalar[dtype]],
    size: Int,
):
    shared = stack_allocation[
        TPB,
        Scalar[dtype],
        address_space = AddressSpace.SHARED,
    ]()
    global_i = block_dim.x * block_idx.x + thread_idx.x
    local_i = thread_idx.x

    # Load local data into shared memory
    if global_i < size:
        shared[local_i] = a[global_i]

    # Wait for all threads to complete; works within a thread block
    barrier()

    # FILL ME IN (roughly 2 lines)
```
View full file: problems/p08/p08.mojo
## Tips

- Wait for the shared memory load with `barrier()` (educational; not strictly needed here)
- Use `local_i` to access shared memory: `shared[local_i]`
- Use `global_i` for output: `output[global_i]`
- Add a guard: `if global_i < size`
## Running the code

To test your solution, run one of the following commands in your terminal:

```bash
uv run poe p08
```

```bash
pixi run p08
```

Your output will look like this if the puzzle isn't solved yet:

```txt
out: HostBuffer([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
expected: HostBuffer([11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0, 11.0])
```
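For reference, the full file drives the kernel roughly like this. This is a sketch based on the host-side harness these puzzles share; the buffer-creation details here are assumptions, so see `problems/p08/p08.mojo` for the real driver.

```mojo
from gpu.host import DeviceContext


def main():
    # Assumes the aliases and add_10_shared from the snippet above.
    with DeviceContext() as ctx:
        # Output starts zeroed; the input is filled with ones, matching
        # the HostBuffer values shown above.
        out = ctx.enqueue_create_buffer[dtype](SIZE).enqueue_fill(0)
        a = ctx.enqueue_create_buffer[dtype](SIZE).enqueue_fill(1)

        # Launch BLOCKS_PER_GRID blocks of THREADS_PER_BLOCK threads each.
        ctx.enqueue_function[add_10_shared](
            out.unsafe_ptr(),
            a.unsafe_ptr(),
            SIZE,
            grid_dim=BLOCKS_PER_GRID,
            block_dim=THREADS_PER_BLOCK,
        )
        ctx.synchronize()
```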
## Solution
```mojo
fn add_10_shared(
    output: UnsafePointer[Scalar[dtype]],
    a: UnsafePointer[Scalar[dtype]],
    size: Int,
):
    shared = stack_allocation[
        TPB,
        Scalar[dtype],
        address_space = AddressSpace.SHARED,
    ]()
    global_i = block_dim.x * block_idx.x + thread_idx.x
    local_i = thread_idx.x

    # Load local data into shared memory
    if global_i < size:
        shared[local_i] = a[global_i]

    # Wait for all threads in the block to finish loading.
    # Note: the barrier is not strictly needed here since each thread only
    # accesses its own shared memory location; it's included to teach the
    # synchronization pattern needed when threads coordinate on shared data.
    barrier()

    # Process using shared memory
    if global_i < size:
        output[global_i] = shared[local_i] + 10
```
This solution demonstrates key concepts of shared memory usage in GPU programming:
- **Memory hierarchy**
  - Global memory: the `a` and `output` arrays (slow, visible to all blocks)
  - Shared memory: the `shared` array (fast, local to the thread block)
  - Example for 8 elements with 4 threads per block:

    ```txt
    Global array a: [1 1 1 1 | 1 1 1 1]  # Input: all ones

    Block (0):      Block (1):
    shared[0..3]    shared[0..3]
    [1 1 1 1]       [1 1 1 1]
    ```

- **Thread coordination**
  - Load phase:

    ```txt
    Thread 0: shared[0] = a[0] = 1    Thread 2: shared[2] = a[2] = 1
    Thread 1: shared[1] = a[1] = 1    Thread 3: shared[3] = a[3] = 1
                ↓  ↓  ↓  ↓
    barrier()  # Wait for all loads
    ```

  - Process phase: each thread adds 10 to its shared memory value
  - Result: `output[i] = shared[local_i] + 10 = 11`

  Note: In this specific case, the `barrier()` isn't strictly necessary since each thread only writes to and reads from its own shared memory location (`shared[local_i]`). However, it's included for educational purposes to demonstrate proper shared memory synchronization patterns that are essential when threads need to access each other's data.

- **Index mapping** (see the worked example after this list)
  - Global index: `block_dim.x * block_idx.x + thread_idx.x`

    ```txt
    Block 0 output: [11 11 11 11]
    Block 1 output: [11 11 11 11]
    ```

  - Local index: `thread_idx.x` for shared memory access

    ```txt
    Both blocks process: 1 + 10 = 11
    ```

- **Memory access pattern**
  - Load: global → shared (coalesced reads of 1s)
  - Sync: `barrier()` ensures all loads complete
  - Process: add 10 to the shared values
  - Store: write the 11s back to global memory
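As a concrete check of the index math: with `TPB = 4`, block 1's thread 2 computes `global_i = 4 * 1 + 2 = 6` and `local_i = 2`, so it loads `a[6]` into `shared[2]` and, after the barrier, writes `shared[2] + 10` to `output[6]`.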
This pattern shows how to use shared memory to optimize data access while maintaining thread coordination within blocks.
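For contrast, a pure element-wise operation like this one could skip shared memory entirely, as in the earlier puzzles; the shared-memory version above exists to teach the load → sync → compute pattern. A hypothetical shared-memory-free sketch:

```mojo
# Hypothetical variant without shared memory: each thread reads and writes
# global memory directly, which suffices for an element-wise op.
fn add_10_global(
    output: UnsafePointer[Scalar[dtype]],
    a: UnsafePointer[Scalar[dtype]],
    size: Int,
):
    global_i = block_dim.x * block_idx.x + thread_idx.x
    if global_i < size:
        output[global_i] = a[global_i] + 10.0
```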