Simple Version
Implement a kernel that computes a prefix-sum over 1D LayoutTensor `a` and stores it in 1D LayoutTensor `output`.

Note: If the size of `a` is greater than the block size, only store the sum of each block.
Configuration
- Array size: `SIZE = 8` elements
- Threads per block: `TPB = 8`
- Number of blocks: 1
- Shared memory: `TPB` elements
Notes:
- Data loading: Each thread loads one element using LayoutTensor access
- Memory pattern: Shared memory for intermediate results using `LayoutTensorBuild`
- Thread sync: Coordination between computation phases
- Access pattern: Stride-based parallel computation
- Type safety: Leveraging LayoutTensor's type system
Code to complete
alias TPB = 8
alias SIZE = 8
alias BLOCKS_PER_GRID = (1, 1)
alias THREADS_PER_BLOCK = (TPB, 1)
alias dtype = DType.float32
alias layout = Layout.row_major(SIZE)
fn prefix_sum_simple[
layout: Layout
](
output: LayoutTensor[mut=False, dtype, layout],
a: LayoutTensor[mut=False, dtype, layout],
size: Int,
):
global_i = block_dim.x * block_idx.x + thread_idx.x
local_i = thread_idx.x
# FILL ME IN (roughly 18 lines)
View full file: problems/p12/p12.mojo
Tips
- Load data into `shared[local_i]`
- Use `offset = 1` and double it each step
- Add elements where `local_i >= offset`
- Call `barrier()` between steps
Running the code
To test your solution, run one of the following commands in your terminal (depending on whether you use uv or pixi):

uv run poe p12 --simple
pixi run p12 --simple
Your output will look like this if the puzzle isn't solved yet:
out: DeviceBuffer([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
expected: HostBuffer([0.0, 1.0, 3.0, 6.0, 10.0, 15.0, 21.0, 28.0])
Solution
fn prefix_sum_simple[
layout: Layout
](
output: LayoutTensor[mut=False, dtype, layout],
a: LayoutTensor[mut=False, dtype, layout],
size: Int,
):
global_i = block_dim.x * block_idx.x + thread_idx.x
local_i = thread_idx.x
shared = tb[dtype]().row_major[TPB]().shared().alloc()
if global_i < size:
shared[local_i] = a[global_i]
barrier()
offset = 1
for i in range(Int(log2(Scalar[dtype](TPB)))):
var current_val: output.element_type = 0
if local_i >= offset and local_i < size:
current_val = shared[local_i - offset] # read
barrier()
if local_i >= offset and local_i < size:
shared[local_i] += current_val
barrier()
offset *= 2
if global_i < size:
output[global_i] = shared[local_i]
The parallel (inclusive) prefix-sum algorithm works as follows:
Setup & Configuration
- `TPB` (Threads Per Block) = 8
- `SIZE` (Array Size) = 8

With `TPB = 8`, the scan needs \(\log_2(8) = 3\) doubling steps (`offset` = 1, 2, 4), which is why the solution's loop runs `Int(log2(Scalar[dtype](TPB)))` times.
Race Condition Prevention
The algorithm uses explicit synchronization to prevent read-write hazards:
- Read Phase: All threads first read the values they need into a local variable `current_val`
- Synchronization: `barrier()` ensures all reads complete before any writes begin
- Write Phase: All threads then safely write their computed values back to shared memory
This prevents the race condition that would occur if threads simultaneously read from and write to the same shared memory locations.
Alternative approach: Another way to prevent race conditions is double buffering: allocate twice the shared memory and alternate between reading from one buffer and writing to the other. This eliminates the read-write hazard entirely, but it requires more shared memory and adds bookkeeping. For educational purposes, we use the explicit synchronization approach as it's more straightforward to understand.
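To make the comparison concrete, here is a minimal sketch of what that double-buffered variant could look like. It is not the puzzle's official solution: the kernel name, the buffer names (`buf_a`, `buf_b`), and the ping-pong flag are invented for this illustration, and it assumes the same aliases (`TPB`, `dtype`) and imports as the solution above.

```mojo
fn prefix_sum_double_buffer[
    layout: Layout
](
    output: LayoutTensor[mut=False, dtype, layout],
    a: LayoutTensor[mut=False, dtype, layout],
    size: Int,
):
    global_i = block_dim.x * block_idx.x + thread_idx.x
    local_i = thread_idx.x
    # Two shared buffers instead of one: each step reads from one buffer and
    # writes to the other, so no thread ever reads a slot that another thread
    # is writing in the same step.
    buf_a = tb[dtype]().row_major[TPB]().shared().alloc()
    buf_b = tb[dtype]().row_major[TPB]().shared().alloc()
    if global_i < size:
        buf_a[local_i] = a[global_i]
    barrier()

    offset = 1
    read_from_a = True
    while offset < TPB:
        if read_from_a:
            # Read from buf_a, write the updated prefix into buf_b.
            if local_i >= offset:
                buf_b[local_i] = buf_a[local_i] + buf_a[local_i - offset]
            else:
                buf_b[local_i] = buf_a[local_i]
        else:
            # Roles swapped: read from buf_b, write into buf_a.
            if local_i >= offset:
                buf_a[local_i] = buf_b[local_i] + buf_b[local_i - offset]
            else:
                buf_a[local_i] = buf_b[local_i]
        barrier()
        read_from_a = not read_from_a
        offset *= 2

    # The final result sits in whichever buffer was written last.
    if global_i < size:
        if read_from_a:
            output[global_i] = buf_a[local_i]
        else:
            output[global_i] = buf_b[local_i]
```

Note that each step needs only one `barrier()`, since reads and writes never touch the same buffer; the price is twice the shared memory and the bookkeeping of which buffer holds the latest values.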
Thread Mapping
- `thread_idx.x`: \([0, 1, 2, 3, 4, 5, 6, 7]\) (`local_i`)
- `block_idx.x`: \([0, 0, 0, 0, 0, 0, 0, 0]\)
- `global_i`: \([0, 1, 2, 3, 4, 5, 6, 7]\) (`block_idx.x * TPB + thread_idx.x`)
Initial Load to Shared Memory
Threads:     T₀  T₁  T₂  T₃  T₄  T₅  T₆  T₇
Input array: [0   1   2   3   4   5   6   7]
shared:      [0   1   2   3   4   5   6   7]
              ↑   ↑   ↑   ↑   ↑   ↑   ↑   ↑
             T₀  T₁  T₂  T₃  T₄  T₅  T₆  T₇
Offset = 1: First Parallel Step
Active threads: \(T_1 \ldots T_7\) (where `local_i >= 1`)
Read Phase: Each thread reads the value it needs:
T₁ reads shared[0] = 0    T₅ reads shared[4] = 4
T₂ reads shared[1] = 1    T₆ reads shared[5] = 5
T₃ reads shared[2] = 2    T₇ reads shared[6] = 6
T₄ reads shared[3] = 3
Synchronization: `barrier()` ensures all reads complete
Write Phase: Each thread adds its read value to its current position:
Before:  [0   1   2   3   4   5   6   7]
Add:          +0  +1  +2  +3  +4  +5  +6
              |   |   |   |   |   |   |
Result:  [0   1   3   5   7   9  11  13]
              ↑   ↑   ↑   ↑   ↑   ↑   ↑
              T₁  T₂  T₃  T₄  T₅  T₆  T₇
Offset = 2: Second Parallel Step
Active threads: \(T_2 \ldots T_7\) (where `local_i >= 2`)
Read Phase: Each thread reads the value it needs:
T₂ reads shared[0] = 0    T₅ reads shared[3] = 5
T₃ reads shared[1] = 1    T₆ reads shared[4] = 7
T₄ reads shared[2] = 3    T₇ reads shared[5] = 9
Synchronization: `barrier()` ensures all reads complete
Write Phase: Each thread adds its read value:
Before:  [0   1   3   5   7   9  11  13]
Add:              +0  +1  +3  +5  +7  +9
                  |   |   |   |   |   |
Result:  [0   1   3   6  10  14  18  22]
                  ↑   ↑   ↑   ↑   ↑   ↑
                  T₂  T₃  T₄  T₅  T₆  T₇
Offset = 4: Third Parallel Step
Active threads: \(T_4 \ldots T_7\) (where `local_i >= 4`)
Read Phase: Each thread reads the value it needs:
T₄ reads shared[0] = 0    T₆ reads shared[2] = 3
T₅ reads shared[1] = 1    T₇ reads shared[3] = 6
Synchronization: `barrier()` ensures all reads complete
Write Phase: Each thread adds its read value:
Before:  [0   1   3   6  10  14  18  22]
Add:                      +0  +1  +3  +6
                          |   |   |   |
Result:  [0   1   3   6  10  15  21  28]
                          ↑   ↑   ↑   ↑
                          T₄  T₅  T₆  T₇
Final Write to Output
Threads:   T₀  T₁  T₂  T₃  T₄  T₅  T₆  T₇
global_i:   0   1   2   3   4   5   6   7
output:    [0   1   3   6  10  15  21  28]
            ↑   ↑   ↑   ↑   ↑   ↑   ↑   ↑
           T₀  T₁  T₂  T₃  T₄  T₅  T₆  T₇
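The intermediate arrays in this walkthrough can be reproduced without a GPU. The following is a small CPU-only Mojo sketch, written for this page rather than taken from the puzzle repository, that models the same offset-doubling scan: it snapshots the values each position would read before applying any writes, mirroring what the kernel's barriers guarantee.

```mojo
fn print_row(vals: List[Float32]):
    for i in range(len(vals)):
        print(vals[i], end=" ")
    print()

def main():
    # Same input as the puzzle: [0, 1, 2, ..., 7]
    var shared = List[Float32]()
    for i in range(8):
        shared.append(Float32(i))

    var offset = 1
    while offset < len(shared):
        # "Read phase": snapshot what each position would read at this offset.
        var reads = List[Float32]()
        for i in range(len(shared)):
            if i >= offset:
                reads.append(shared[i - offset])
            else:
                reads.append(Float32(0))
        # "Write phase": apply all the additions at once.
        for i in range(len(shared)):
            shared[i] += reads[i]
        print_row(shared)  # array after this offset step
        offset *= 2
```

Running it prints the array after each offset step, matching the three Result rows above: 0 1 3 5 7 9 11 13, then 0 1 3 6 10 14 18 22, then the final 0 1 3 6 10 15 21 28 (as floating-point values).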
Key Implementation Details
Synchronization Pattern: Each iteration follows a strict read → sync → write pattern:
1. `var current_val: output.element_type = 0` - Initialize local variable
2. `current_val = shared[local_i - offset]` - Read phase (if conditions met)
3. `barrier()` - Explicit synchronization to prevent race conditions
4. `shared[local_i] += current_val` - Write phase (if conditions met)
5. `barrier()` - Standard synchronization before next iteration
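For reference, here are those five steps as they appear in the solution's loop body, with the step numbers inlined as comments:

```mojo
var current_val: output.element_type = 0    # 1. initialize the temporary
if local_i >= offset and local_i < size:
    current_val = shared[local_i - offset]  # 2. read phase
barrier()                                   # 3. all reads finish before any write
if local_i >= offset and local_i < size:
    shared[local_i] += current_val          # 4. write phase
barrier()                                   # 5. sync before the next offset
```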
Race Condition Prevention: Without the explicit read-write separation, multiple threads could simultaneously access the same shared memory location, leading to undefined behavior. The two-phase approach with explicit synchronization ensures correctness.
Memory Safety: The algorithm maintains memory safety through:
- Bounds checking with `if local_i >= offset and local_i < size`
- Proper initialization of the temporary variable
- Coordinated access patterns that prevent data races
The solution ensures correct synchronization between phases using `barrier()` and handles array bounds checking with `if global_i < size`. The final result is the inclusive prefix sum, where element \(i\) contains \(\sum_{j=0}^{i} a[j]\).
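For completeness, here is a rough sketch of how a host-side harness might launch this kernel, modeled on the general pattern the puzzle files use (a DeviceContext, device buffers, and an enqueued kernel launch). It assumes the aliases and the `prefix_sum_simple` kernel defined above are in scope; the specific calls (`enqueue_create_buffer`, `enqueue_fill`, `map_to_host`, `enqueue_function`) are reproduced from that general pattern rather than copied from problems/p12/p12.mojo, so treat it as an illustrative sketch, not the repository's actual test code.

```mojo
from gpu.host import DeviceContext
from layout import Layout, LayoutTensor

def main():
    with DeviceContext() as ctx:
        # Device buffers for the result and the input, zero-initialized.
        out = ctx.enqueue_create_buffer[dtype](SIZE).enqueue_fill(0)
        a = ctx.enqueue_create_buffer[dtype](SIZE).enqueue_fill(0)

        # Fill the input with 0, 1, ..., SIZE - 1 from the host.
        with a.map_to_host() as a_host:
            for i in range(SIZE):
                a_host[i] = i

        # Wrap the raw buffers in LayoutTensors matching the kernel signature.
        var a_tensor = LayoutTensor[mut=False, dtype, layout](a.unsafe_ptr())
        var out_tensor = LayoutTensor[mut=False, dtype, layout](out.unsafe_ptr())

        # Launch one block of TPB threads.
        ctx.enqueue_function[prefix_sum_simple[layout]](
            out_tensor,
            a_tensor,
            SIZE,
            grid_dim=BLOCKS_PER_GRID,
            block_dim=THREADS_PER_BLOCK,
        )
        ctx.synchronize()

        # Copy the result back and print it for inspection.
        with out.map_to_host() as out_host:
            print("out:", out_host)
```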