warp.prefix_sum()
Hardware-Optimized Parallel Scan
For warp-level parallel scan operations, we can use `prefix_sum()` to replace complex shared memory algorithms with a single hardware-optimized primitive. This operation enables efficient cumulative computations, parallel partitioning, and advanced coordination algorithms that would otherwise require dozens of lines of shared memory and synchronization code.
Key insight: The prefix_sum() operation leverages hardware-accelerated parallel scan to compute cumulative operations across warp lanes with \(O(\log n)\) complexity, replacing complex multi-phase algorithms with single function calls.
What is parallel scan? Parallel scan (prefix sum) is a fundamental parallel primitive that computes cumulative operations across data elements. For addition, it transforms `[a, b, c, d]` into `[a, a+b, a+b+c, a+b+c+d]`. This operation is essential for parallel algorithms like stream compaction, quicksort partitioning, and parallel sorting.
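To make this concrete, here is a minimal sequential sketch (plain CPU code, not part of the puzzle) that computes the inclusive scan of the example input `[1, 2, 3, 4]` one element at a time:

```mojo
# Minimal sequential sketch of an inclusive scan over the example input [1, 2, 3, 4].
# Purely illustrative: the warp kernels below replace this loop with a single primitive.
fn inclusive_scan_demo():
    var running = Float32(0)
    for i in range(1, 5):
        running += Float32(i)  # element at position i - 1 of the example input
        # running takes the values 1.0, 3.0, 6.0, 10.0
        print("inclusive prefix at position", i - 1, "=", running)
```

Each output position depends on every earlier input, which is why a parallel version needs either the multi-phase shared memory algorithm from Puzzle 12 or the warp primitive introduced here.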
Key concepts
In this puzzle, you’ll master:
- Hardware-optimized parallel scan with `prefix_sum()`
- Inclusive vs exclusive prefix sum patterns (illustrated below)
- Warp-level stream compaction for data reorganization
- Advanced parallel partitioning combining multiple warp primitives
- Single-warp algorithm optimization replacing complex shared memory

This transforms multi-phase shared memory algorithms into elegant single-function calls, enabling efficient parallel scan operations without explicit synchronization.
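Before the first puzzle, here is the inclusive vs. exclusive distinction in miniature, written as a fragment in the style of the snippets below and using the same `exclusive` parameter that both solutions rely on (the bracketed values are what each lane would receive):

```mojo
# Warp lanes 0..3 hold the values [1, 2, 3, 4] in x:
inclusive = prefix_sum[exclusive=False](x)  # lanes receive [1, 3, 6, 10]: own value included
exclusive = prefix_sum[exclusive=True](x)   # lanes receive [0, 1, 3, 6]:  own value excluded
```

The first puzzle uses the inclusive form to build cumulative sums; the second uses the exclusive form to turn 0/1 predicates into write positions.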
1. Warp inclusive prefix sum
Configuration
- Vector size: `SIZE = WARP_SIZE` (32 or 64 depending on GPU)
- Grid configuration: `(1, 1)` blocks per grid
- Block configuration: `(WARP_SIZE, 1)` threads per block
- Data type: `DType.float32`
- Layout: `Layout.row_major(SIZE)` (1D row-major)
The `prefix_sum()` advantage
Traditional prefix sum requires complex multi-phase shared memory algorithms. In Puzzle 12, we implemented this the hard way with explicit shared memory management:
```mojo
fn prefix_sum_simple[
    layout: Layout
](
    output: LayoutTensor[mut=False, dtype, layout],
    a: LayoutTensor[mut=False, dtype, layout],
    size: Int,
):
    global_i = block_dim.x * block_idx.x + thread_idx.x
    local_i = thread_idx.x
    shared = tb[dtype]().row_major[TPB]().shared().alloc()
    if global_i < size:
        shared[local_i] = a[global_i]
    barrier()

    offset = 1
    for i in range(Int(log2(Scalar[dtype](TPB)))):
        var current_val: output.element_type = 0
        if local_i >= offset and local_i < size:
            current_val = shared[local_i - offset]  # read
        barrier()
        if local_i >= offset and local_i < size:
            shared[local_i] += current_val
        barrier()
        offset *= 2

    if global_i < size:
        output[global_i] = shared[local_i]
```
Problems with the traditional approach:
- Memory overhead: Requires shared memory allocation
- Multiple barriers: Complex multi-phase synchronization
- Complex indexing: Manual stride calculation and boundary checking
- Poor scaling: \(O(\log n)\) phases with barriers between each
With `prefix_sum()`, parallel scan becomes trivial:

```mojo
# Hardware-optimized approach - single function call!
current_val = input[global_i]
scan_result = prefix_sum[exclusive=False](current_val)
output[global_i] = scan_result
```
Benefits of `prefix_sum()`:
- Zero memory overhead: Hardware-accelerated, register-only computation
- No explicit synchronization: A single warp-level operation, no barriers required
- Hardware optimized: Leverages specialized scan units
- Perfect scaling: Works for any `WARP_SIZE` (32, 64, etc.)
Code to complete
Implement inclusive prefix sum using the hardware-optimized `prefix_sum()` primitive.
Mathematical operation: Compute cumulative sum where each lane gets the sum of all elements up to and including its position: \[\Large \text{output}[i] = \sum_{j=0}^{i} \text{input}[j]\]
This transforms input data `[1, 2, 3, 4, 5, ...]` into cumulative sums `[1, 3, 6, 10, 15, ...]`, where each position contains the sum of all previous elements plus itself.
```mojo
fn warp_inclusive_prefix_sum[
    layout: Layout, size: Int
](
    output: LayoutTensor[mut=False, dtype, layout],
    input: LayoutTensor[mut=False, dtype, layout],
):
    """
    Inclusive prefix sum using warp primitive: Each thread gets sum of all
    elements up to and including its position.

    Compare this to Puzzle 12's complex shared memory + barrier approach.

    Puzzle 12 approach:
    - Shared memory allocation
    - Multiple barrier synchronizations
    - Log(n) iterations with manual tree reduction
    - Complex multi-phase algorithm

    Warp prefix_sum approach:
    - Single function call!
    - Hardware-optimized parallel scan
    - Automatic synchronization
    - O(log n) complexity, but implemented in hardware.

    NOTE: This implementation only works correctly within a single warp (WARP_SIZE threads).
    For multi-warp scenarios, additional coordination would be needed.
    """
    global_i = block_dim.x * block_idx.x + thread_idx.x
    # FILL ME IN (roughly 4 lines)
```
View full file: problems/p24/p24.mojo
Tips
1. Understanding prefix_sum parameters
The `prefix_sum()` function has an important template parameter that controls the scan type.
Key questions:
- What’s the difference between inclusive and exclusive prefix sum?
- Which parameter controls this behavior?
- For inclusive scan, what should each lane output?
Hint: Look at the function signature and consider what “inclusive” means for cumulative operations.
2. Single warp limitation
This hardware primitive only works within a single warp. Consider the implications.
Think about:
- What happens if you have multiple warps?
- Why is this limitation important to understand?
- How would you extend this to multi-warp scenarios?
3. Data type considerations
The `prefix_sum` function may require specific data types for optimal performance.
Consider:
- What data type does your input use?
- Does `prefix_sum` expect a specific scalar type?
- How do you handle type conversions if needed?
Test the warp inclusive prefix sum:
```bash
uv run poe p24 --prefix-sum
# or, with pixi:
pixi run p24 --prefix-sum
```
Expected output when solved:

```txt
WARP_SIZE: 32
SIZE: 32
output: [1.0, 3.0, 6.0, 10.0, 15.0, 21.0, 28.0, 36.0, 45.0, 55.0, 66.0, 78.0, 91.0, 105.0, 120.0, 136.0, 153.0, 171.0, 190.0, 210.0, 231.0, 253.0, 276.0, 300.0, 325.0, 351.0, 378.0, 406.0, 435.0, 465.0, 496.0, 528.0]
expected: [1.0, 3.0, 6.0, 10.0, 15.0, 21.0, 28.0, 36.0, 45.0, 55.0, 66.0, 78.0, 91.0, 105.0, 120.0, 136.0, 153.0, 171.0, 190.0, 210.0, 231.0, 253.0, 276.0, 300.0, 325.0, 351.0, 378.0, 406.0, 435.0, 465.0, 496.0, 528.0]
✅ Warp inclusive prefix sum test passed!
```
Solution
```mojo
fn warp_inclusive_prefix_sum[
    layout: Layout, size: Int
](
    output: LayoutTensor[mut=False, dtype, layout],
    input: LayoutTensor[mut=False, dtype, layout],
):
    """
    Inclusive prefix sum using warp primitive: Each thread gets sum of all
    elements up to and including its position.

    Compare this to Puzzle 12's complex shared memory + barrier approach.

    Puzzle 12 approach:
    - Shared memory allocation
    - Multiple barrier synchronizations
    - Log(n) iterations with manual tree reduction
    - Complex multi-phase algorithm

    Warp prefix_sum approach:
    - Single function call!
    - Hardware-optimized parallel scan
    - Automatic synchronization
    - O(log n) complexity, but implemented in hardware.

    NOTE: This implementation only works correctly within a single warp (WARP_SIZE threads).
    For multi-warp scenarios, additional coordination would be needed.
    """
    global_i = block_dim.x * block_idx.x + thread_idx.x
    if global_i < size:
        current_val = input[global_i]
        # This one call replaces ~30 lines of complex shared memory logic from Puzzle 12!
        # But it only works within the current warp (WARP_SIZE threads)
        scan_result = prefix_sum[exclusive=False](
            rebind[Scalar[dtype]](current_val)
        )
        output[global_i] = scan_result
```
This solution demonstrates how `prefix_sum()` replaces complex multi-phase algorithms with a single hardware-optimized function call.
Algorithm breakdown:

```mojo
if global_i < size:
    current_val = input[global_i]
    # This one call replaces ~30 lines of complex shared memory logic from Puzzle 12!
    # But it only works within the current warp (WARP_SIZE threads)
    scan_result = prefix_sum[exclusive=False](
        rebind[Scalar[dtype]](current_val)
    )
    output[global_i] = scan_result
```
SIMT execution deep dive:
```txt
Input: [1, 2, 3, 4, 5, 6, 7, 8, ...]

Cycle 1: All lanes load their values simultaneously
  Lane 0: current_val = 1
  Lane 1: current_val = 2
  Lane 2: current_val = 3
  Lane 3: current_val = 4
  ...
  Lane 31: current_val = 32

Cycle 2: prefix_sum[exclusive=False] executes (hardware-accelerated)
  Lane 0: scan_result = 1    (sum of elements 0 to 0)
  Lane 1: scan_result = 3    (sum of elements 0 to 1: 1+2)
  Lane 2: scan_result = 6    (sum of elements 0 to 2: 1+2+3)
  Lane 3: scan_result = 10   (sum of elements 0 to 3: 1+2+3+4)
  ...
  Lane 31: scan_result = 528 (sum of elements 0 to 31)

Cycle 3: Store results
  Lane 0: output[0] = 1
  Lane 1: output[1] = 3
  Lane 2: output[2] = 6
  Lane 3: output[3] = 10
  ...
```
Mathematical insight: This implements the inclusive prefix sum operation: \[\Large \text{output}[i] = \sum_{j=0}^{i} \text{input}[j]\]
Comparison with Puzzle 12’s approach:
- Puzzle 12: ~30 lines of shared memory + multiple barriers + complex indexing
- Warp primitive: 1 function call with hardware acceleration
- Performance: Same \(O(\log n)\) complexity, but implemented in specialized hardware
- Memory: Zero shared memory usage vs explicit allocation
Evolution from Puzzle 12: This demonstrates the power of modern GPU architectures: what required careful manual implementation in Puzzle 12 is now a single hardware-accelerated primitive. The warp-level `prefix_sum()` gives you the same algorithmic benefits with zero implementation complexity.
Why prefix_sum is superior:
- Hardware acceleration: Dedicated scan units on modern GPUs
- Zero memory overhead: No shared memory allocation required
- Automatic synchronization: No explicit barriers needed
- Perfect scaling: Works optimally for any `WARP_SIZE`
Performance characteristics:
- Latency: ~1-2 cycles (hardware scan units)
- Bandwidth: Zero memory traffic (register-only operation)
- Parallelism: All `WARP_SIZE` lanes participate simultaneously
- Scalability: \(O(\log n)\) complexity with hardware optimization
Important limitation: This primitive only works within a single warp. For multi-warp scenarios, you would need additional coordination between warps.
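For completeness, here is one common way such multi-warp coordination is structured: each warp scans its own slice, the per-warp totals are scanned once more, and the resulting offsets are added back. This is an untested sketch, not part of the puzzle; `NUM_WARPS` and the kernel name are hypothetical, and it assumes the same imports and aliases as the rest of p24.mojo plus the shared-memory `tb` builder and `barrier()` shown in the Puzzle 12 excerpt above.

```mojo
# Untested sketch: a block-wide inclusive scan built from per-warp scans.
# NUM_WARPS is hypothetical; assumes the block launches NUM_WARPS * WARP_SIZE threads
# and that NUM_WARPS <= WARP_SIZE.
alias NUM_WARPS = 4

fn block_inclusive_prefix_sum[
    layout: Layout, size: Int
](
    output: LayoutTensor[mut=False, dtype, layout],
    input: LayoutTensor[mut=False, dtype, layout],
):
    global_i = block_dim.x * block_idx.x + thread_idx.x
    local_i = thread_idx.x
    warp_id = local_i // WARP_SIZE
    lane = local_i % WARP_SIZE

    # Shared scratch space: one slot per warp for its total (later, its offset)
    warp_totals = tb[dtype]().row_major[NUM_WARPS]().shared().alloc()

    # Step 1: every warp scans its own WARP_SIZE-element slice
    var my_val: Scalar[dtype] = 0
    if global_i < size:
        my_val = rebind[Scalar[dtype]](input[global_i])
    local_scan = prefix_sum[exclusive=False](my_val)

    # Step 2: the last lane of each warp publishes that warp's total
    if lane == WARP_SIZE - 1:
        warp_totals[warp_id] = local_scan
    barrier()

    # Step 3: warp 0 converts per-warp totals into per-warp offsets (exclusive scan)
    if warp_id == 0:
        var total: Scalar[dtype] = 0
        if lane < NUM_WARPS:
            total = rebind[Scalar[dtype]](warp_totals[lane])
        offsets = prefix_sum[exclusive=True](total)
        if lane < NUM_WARPS:
            warp_totals[lane] = offsets
    barrier()

    # Step 4: each thread adds the combined total of all preceding warps
    if global_i < size:
        output[global_i] = local_scan + rebind[Scalar[dtype]](warp_totals[warp_id])
```

This is the standard two-level scan pattern: the extra exclusive scan of the per-warp totals supplies each warp with the sum of everything processed by the warps before it.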
2. Warp partition
Configuration
- Vector size: `SIZE = WARP_SIZE` (32 or 64 depending on GPU)
- Grid configuration: `(1, 1)` blocks per grid
- Block configuration: `(WARP_SIZE, 1)` threads per block
Code to complete
Implement single-warp parallel partitioning using BOTH `shuffle_xor` AND `prefix_sum` primitives.
Mathematical operation: Partition elements around a pivot value, placing elements `< pivot` on the left and elements `>= pivot` on the right:
\[\Large \text{output} = [\text{elements} < \text{pivot}] \,|\, [\text{elements} \geq \text{pivot}]\]
Advanced algorithm: This combines two sophisticated warp primitives:
- `shuffle_xor()`: Butterfly pattern for warp-level reduction (count left elements)
- `prefix_sum()`: Exclusive scan for position calculation within partitions
This demonstrates the power of combining multiple warp primitives for complex parallel algorithms within a single warp.
```mojo
fn warp_partition[
    layout: Layout, size: Int
](
    output: LayoutTensor[mut=False, dtype, layout],
    input: LayoutTensor[mut=False, dtype, layout],
    pivot: Float32,
):
    """
    Single-warp parallel partitioning using BOTH shuffle_xor AND prefix_sum.

    This implements a warp-level quicksort partition step that places elements < pivot
    on the left and elements >= pivot on the right.

    ALGORITHM COMPLEXITY - combines two advanced warp primitives:
    1. shuffle_xor(): Butterfly pattern for warp-level reductions
    2. prefix_sum(): Warp-level exclusive scan for position calculation.

    This demonstrates the power of warp primitives for sophisticated parallel algorithms
    within a single warp (works for any WARP_SIZE: 32, 64, etc.).

    Example with pivot=5:
        Input:  [3, 7, 1, 8, 2, 9, 4, 6]
        Result: [3, 1, 2, 4, 7, 8, 9, 6] (< pivot | >= pivot).
    """
    global_i = block_dim.x * block_idx.x + thread_idx.x
    if global_i < size:
        current_val = input[global_i]
        # FILL ME IN (roughly 13 lines)
```
Tips
1. Multi-phase algorithm structure
This algorithm requires several coordinated phases. Think about the logical steps needed for partitioning.
Key phases to consider:
- How do you identify which elements belong to which partition?
- How do you calculate positions within each partition?
- How do you determine the total size of the left partition?
- How do you write elements to their final positions?
2. Predicate creation
You need to create boolean predicates to identify partition membership.
Think about:
- How do you represent “this element belongs to the left partition”?
- How do you represent “this element belongs to the right partition”?
- What data type should you use for predicates that work with `prefix_sum`?
3. Combining shuffle_xor and prefix_sum
This algorithm uses both warp primitives for different purposes.
Consider:
- What is `shuffle_xor` used for in this context?
- What is `prefix_sum` used for in this context?
- How do these two operations work together?
4. Position calculation
The trickiest part is calculating where each element should be written in the output.
Key insights:
- Left partition elements: What determines their final position?
- Right partition elements: How do you offset them correctly?
- How do you combine local positions with partition boundaries?
Test the warp partition:
```bash
uv run poe p24 --partition
# or, with pixi:
pixi run p24 --partition
```
Expected output when solved:

```txt
WARP_SIZE: 32
SIZE: 32
output: HostBuffer([3.0, 1.0, 2.0, 4.0, 0.0, 3.0, 1.0, 4.0, 3.0, 1.0, 2.0, 4.0, 0.0, 3.0, 1.0, 4.0, 7.0, 8.0, 9.0, 6.0, 10.0, 11.0, 12.0, 13.0, 7.0, 8.0, 9.0, 6.0, 10.0, 11.0, 12.0, 13.0])
expected: HostBuffer([3.0, 1.0, 2.0, 4.0, 0.0, 3.0, 1.0, 4.0, 3.0, 1.0, 2.0, 4.0, 0.0, 3.0, 1.0, 4.0, 7.0, 8.0, 9.0, 6.0, 10.0, 11.0, 12.0, 13.0, 7.0, 8.0, 9.0, 6.0, 10.0, 11.0, 12.0, 13.0])
pivot: 5.0
✅ Warp partition test passed!
```
Solution
```mojo
fn warp_partition[
    layout: Layout, size: Int
](
    output: LayoutTensor[mut=False, dtype, layout],
    input: LayoutTensor[mut=False, dtype, layout],
    pivot: Float32,
):
    """
    Single-warp parallel partitioning using BOTH shuffle_xor AND prefix_sum.

    This implements a warp-level quicksort partition step that places elements < pivot
    on the left and elements >= pivot on the right.

    ALGORITHM COMPLEXITY - combines two advanced warp primitives:
    1. shuffle_xor(): Butterfly pattern for warp-level reductions
    2. prefix_sum(): Warp-level exclusive scan for position calculation.

    This demonstrates the power of warp primitives for sophisticated parallel algorithms
    within a single warp (works for any WARP_SIZE: 32, 64, etc.).

    Example with pivot=5:
        Input:  [3, 7, 1, 8, 2, 9, 4, 6]
        Result: [3, 1, 2, 4, 7, 8, 9, 6] (< pivot | >= pivot).
    """
    global_i = block_dim.x * block_idx.x + thread_idx.x
    if global_i < size:
        current_val = input[global_i]

        # Phase 1: Create warp-level predicates
        predicate_left = Float32(1.0) if current_val < pivot else Float32(0.0)
        predicate_right = Float32(1.0) if current_val >= pivot else Float32(0.0)

        # Phase 2: Warp-level prefix sum to get positions within warp
        warp_left_pos = prefix_sum[exclusive=True](predicate_left)
        warp_right_pos = prefix_sum[exclusive=True](predicate_right)

        # Phase 3: Get total left count using shuffle_xor reduction
        warp_left_total = predicate_left
        # Butterfly reduction to get total across the warp: dynamic for any WARP_SIZE
        offset = WARP_SIZE // 2
        while offset > 0:
            warp_left_total += shuffle_xor(warp_left_total, offset)
            offset //= 2

        # Phase 4: Write to output positions
        if current_val < pivot:
            # Left partition: use warp-level position
            output[Int(warp_left_pos)] = current_val
        else:
            # Right partition: offset by total left count + right position
            output[Int(warp_left_total + warp_right_pos)] = current_val
```
This solution demonstrates advanced coordination between multiple warp primitives to implement sophisticated parallel algorithms.
Complete algorithm analysis:

```mojo
if global_i < size:
    current_val = input[global_i]

    # Phase 1: Create warp-level predicates
    predicate_left = Float32(1.0) if current_val < pivot else Float32(0.0)
    predicate_right = Float32(1.0) if current_val >= pivot else Float32(0.0)

    # Phase 2: Warp-level prefix sum to get positions within warp
    warp_left_pos = prefix_sum[exclusive=True](predicate_left)
    warp_right_pos = prefix_sum[exclusive=True](predicate_right)

    # Phase 3: Get total left count using shuffle_xor reduction
    warp_left_total = predicate_left
    # Butterfly reduction to get total across the warp: dynamic for any WARP_SIZE
    offset = WARP_SIZE // 2
    while offset > 0:
        warp_left_total += shuffle_xor(warp_left_total, offset)
        offset //= 2

    # Phase 4: Write to output positions
    if current_val < pivot:
        # Left partition: use warp-level position
        output[Int(warp_left_pos)] = current_val
    else:
        # Right partition: offset by total left count + right position
        output[Int(warp_left_total + warp_right_pos)] = current_val
```
Multi-phase execution trace (8-lane example, pivot=5, values [3,7,1,8,2,9,4,6]):
```txt
Initial state:
  Lane 0: current_val=3 (< 5)    Lane 1: current_val=7 (>= 5)
  Lane 2: current_val=1 (< 5)    Lane 3: current_val=8 (>= 5)
  Lane 4: current_val=2 (< 5)    Lane 5: current_val=9 (>= 5)
  Lane 6: current_val=4 (< 5)    Lane 7: current_val=6 (>= 5)

Phase 1: Create predicates
  Lane 0: predicate_left=1.0, predicate_right=0.0
  Lane 1: predicate_left=0.0, predicate_right=1.0
  Lane 2: predicate_left=1.0, predicate_right=0.0
  Lane 3: predicate_left=0.0, predicate_right=1.0
  Lane 4: predicate_left=1.0, predicate_right=0.0
  Lane 5: predicate_left=0.0, predicate_right=1.0
  Lane 6: predicate_left=1.0, predicate_right=0.0
  Lane 7: predicate_left=0.0, predicate_right=1.0

Phase 2: Exclusive prefix sum for positions
  warp_left_pos:  [0, 1, 1, 2, 2, 3, 3, 4]   (only lanes with values < pivot use these)
  warp_right_pos: [0, 0, 1, 1, 2, 2, 3, 3]   (only lanes with values >= pivot use these)

Phase 3: Butterfly reduction for left total
  Initial: [1, 0, 1, 0, 1, 0, 1, 0]
  After reduction: all lanes have warp_left_total = 4

Phase 4: Write to output positions
  Lane 0: current_val=3 < pivot  → output[0] = 3
  Lane 1: current_val=7 >= pivot → output[4+0] = output[4] = 7
  Lane 2: current_val=1 < pivot  → output[1] = 1
  Lane 3: current_val=8 >= pivot → output[4+1] = output[5] = 8
  Lane 4: current_val=2 < pivot  → output[2] = 2
  Lane 5: current_val=9 >= pivot → output[4+2] = output[6] = 9
  Lane 6: current_val=4 < pivot  → output[3] = 4
  Lane 7: current_val=6 >= pivot → output[4+3] = output[7] = 6

Final result: [3, 1, 2, 4, 7, 8, 9, 6] (< pivot | >= pivot)
```
Mathematical insight: This implements parallel partitioning with dual warp primitives: \[\Large \begin{align} \text{left\_pos}[i] &= \text{prefix\_sum}_{\text{exclusive}}(\text{predicate\_left}[i]) \\ \text{right\_pos}[i] &= \text{prefix\_sum}_{\text{exclusive}}(\text{predicate\_right}[i]) \\ \text{left\_total} &= \text{butterfly\_reduce}(\text{predicate\_left}) \\ \text{final\_pos}[i] &= \begin{cases} \text{left\_pos}[i] & \text{if } \text{input}[i] < \text{pivot} \\ \text{left\_total} + \text{right\_pos}[i] & \text{if } \text{input}[i] \geq \text{pivot} \end{cases} \end{align}\]
Why this multi-primitive approach works:
- Predicate creation: Identifies partition membership for each element
- Exclusive prefix sum: Calculates relative positions within each partition
- Butterfly reduction: Computes partition boundary (total left count)
- Coordinated write: Combines local positions with global partition structure
Algorithm complexity:
- Phase 1: \(O(1)\) - Predicate creation
- Phase 2: \(O(\log n)\) - Hardware-accelerated prefix sum
- Phase 3: \(O(\log n)\) - Butterfly reduction with `shuffle_xor`
- Phase 4: \(O(1)\) - Coordinated write
- Total: \(O(\log n)\) with excellent constants
Performance characteristics:
- Communication steps: \(2 \times \log_2(\text{WARP\_SIZE})\) (prefix sum + butterfly reduction), e.g. 10 steps when \(\text{WARP\_SIZE} = 32\)
- Memory efficiency: Zero shared memory, all register-based
- Parallelism: All lanes active throughout algorithm
- Scalability: Works for any `WARP_SIZE` (32, 64, etc.)
Practical applications: This pattern is fundamental to:
- Quicksort partitioning: Core step in parallel sorting algorithms
- Stream compaction: Removing null/invalid elements from data streams (sketched below)
- Parallel filtering: Separating data based on complex predicates
- Load balancing: Redistributing work based on computational requirements
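As a concrete example of the stream compaction pattern listed above, the sketch below reuses the partition kernel's ingredients (a 0/1 predicate plus an exclusive `prefix_sum`) to keep only the elements at or above a threshold and pack them to the front of the output. It is a hypothetical sketch, not part of the puzzle: the function name and `threshold` parameter are illustrative, and it assumes the same imports and aliases as p24.mojo.

```mojo
# Hypothetical sketch: warp-level stream compaction using an exclusive prefix sum.
# Keeps only elements >= threshold and writes them contiguously from output[0].
fn warp_compact_ge[
    layout: Layout, size: Int
](
    output: LayoutTensor[mut=False, dtype, layout],
    input: LayoutTensor[mut=False, dtype, layout],
    threshold: Float32,
):
    global_i = block_dim.x * block_idx.x + thread_idx.x
    if global_i < size:
        current_val = input[global_i]

        # 1.0 if this lane's element survives the filter, else 0.0
        keep = Float32(1.0) if current_val >= threshold else Float32(0.0)

        # Exclusive scan of the keep-flags = number of surviving elements in lower lanes,
        # which is exactly this element's destination index in the compacted output.
        dest = prefix_sum[exclusive=True](keep)

        if current_val >= threshold:
            output[Int(dest)] = current_val
```

Lanes whose elements are filtered out write nothing, so only the first few output slots (one per surviving element) are meaningful; the survivor count itself could be obtained with the same `shuffle_xor` butterfly reduction used in the partition solution.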
Summary
The `prefix_sum()` primitive enables hardware-accelerated parallel scan operations that replace complex multi-phase algorithms with single function calls. Through these two problems, you’ve mastered:
Core Prefix Sum Patterns
- Inclusive Prefix Sum (`prefix_sum[exclusive=False]`):
  - Hardware-accelerated cumulative operations
  - Replaces ~30 lines of shared memory code with a single function call
  - \(O(\log n)\) complexity with specialized hardware optimization
- Advanced Multi-Primitive Coordination (combining `prefix_sum` + `shuffle_xor`):
  - Sophisticated parallel algorithms within a single warp
  - Exclusive scan for position calculation + butterfly reduction for totals
  - Complex partitioning operations with optimal parallel efficiency
Key Algorithmic Insights
Hardware Acceleration Benefits:
- `prefix_sum()` leverages dedicated scan units on modern GPUs
- Zero shared memory overhead compared to traditional approaches
- Automatic synchronization without explicit barriers
Multi-Primitive Coordination:
```mojo
# Pseudocode recap (butterfly_reduce stands for the shuffle_xor loop from the solution):

# Phase 1: Create predicates for partition membership
predicate = 1.0 if condition else 0.0

# Phase 2: Use prefix_sum for local positions
local_pos = prefix_sum[exclusive=True](predicate)

# Phase 3: Use shuffle_xor for global totals
global_total = butterfly_reduce(predicate)

# Phase 4: Combine for final positioning
final_pos = local_pos + partition_offset
```
Performance Advantages:
- Hardware optimization: Specialized scan units vs software implementation
- Memory efficiency: Register-only operations vs shared memory allocation
- Scalable complexity: \(O(\log n)\) with hardware acceleration
- Single-warp optimization: Perfect for algorithms within `WARP_SIZE` limits
Practical Applications
These prefix sum patterns are fundamental to:
- Parallel scan operations: Cumulative sums, products, min/max scans
- Stream compaction: Parallel filtering and data reorganization
- Quicksort partitioning: Core parallel sorting algorithm building block
- Parallel algorithms: Load balancing, work distribution, data restructuring
The combination of `prefix_sum()` and `shuffle_xor()` demonstrates how modern GPU warp primitives can implement sophisticated parallel algorithms with minimal code complexity and optimal performance characteristics.