Puzzle 16: Softmax Op
Overview
In this puzzle, we’ll implement the softmax function as a custom MAX Graph operation. Softmax takes a vector of real numbers and normalizes it into a probability distribution.
Mathematically, the softmax function is defined as:
$$\Large \text{softmax}(x_i) = \frac{e^{x_i}}{\sum_{j=1}^{n} e^{x_j}}$$
Where:
- \(x_i\) is the \(i\)-th element of the input vector
- \(n\) is the length of the input vector
However, this direct implementation can lead to numerical overflow issues when values are large. To address this, we use a more numerically stable version:
$$\Large \text{softmax}(x_i) = \frac{e^{x_i - \max(x)}}{\sum_{j=1}^{n} e^{x_j - \max(x)}}$$
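To see why the shift matters, here is a small NumPy sketch (illustrative only, not part of the puzzle code): the naive form overflows float32 for large inputs, while the shifted form stays finite.

import numpy as np

x = np.array([1000.0, 1001.0, 1002.0], dtype=np.float32)

# Naive softmax: exp(1000) overflows float32, producing inf and then nan.
naive = np.exp(x) / np.exp(x).sum()

# Stable softmax: subtracting max(x) keeps every exponent <= 0.
shifted = np.exp(x - x.max())
stable = shifted / shifted.sum()

print(naive)   # [nan nan nan]
print(stable)  # [0.09003057 0.24472848 0.66524094]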
Our GPU implementation uses parallel reduction for both finding the maximum value and computing the sum of exponentials, making it highly efficient for large vectors.
Key concepts
- Parallel reduction for efficient maximum and sum calculations
- Numerical stability through max-subtraction technique
- Shared memory usage for thread communication
- Custom MAX Graph operation integration with Python
- Thread synchronization with barriers
Configuration
- Vector size: \(\text{SIZE} = 128\)
- Threads per block: \(\text{TPB} = 128\)
- Grid dimensions: \(1 \times 1\) block
- Shared memory: Two TPB-sized shared buffers, one for the max reduction and one for the sum
Layout configuration:
- Input tensor: `Layout.row_major(SIZE)`
- Output tensor: `Layout.row_major(SIZE)`
- Custom op parameters: `{"input_size": input_tensor.shape[0]}`
Key aspects of this puzzle include:
- Numerical stability: Understanding how to handle potential numerical issues
- Parallel reductions: Using shared memory for efficient max and sum calculations
- Custom op integration: Completing the Python interface for our Mojo GPU kernel
- Testing and verification: Ensuring our implementation matches the expected results
Our softmax custom operation will:
- Accept NumPy arrays from Python
- Process them efficiently on the GPU
- Return normalized probability distributions
- Match the results of SciPy’s softmax implementation
Code to complete
To complete this puzzle, you need to implement both the GPU and CPU kernels in the Mojo file and complete the graph definition in the Python code.
1. Implement the GPU kernel:
from gpu import thread_idx, block_idx, block_dim, barrier
from gpu.host import DeviceContext, HostBuffer, DeviceBuffer
from layout import Layout, LayoutTensor
from layout.tensor_builder import LayoutTensorBuild as tb
from math import exp
from utils.numerics import max_finite, min_finite
alias SIZE = 128
alias TPB = 128
alias BLOCKS_PER_GRID = (1, 1)
alias THREADS_PER_BLOCK = (TPB, 1)
alias layout = Layout.row_major(SIZE)
alias dtype = DType.float32
fn softmax_gpu_kernel[
layout: Layout,
input_size: Int,
dtype: DType = DType.float32,
](
out: LayoutTensor[mut=True, dtype, layout],
input: LayoutTensor[mut=False, dtype, layout],
):
# FILL IN (roughly 31 lines)
...
View full file: problems/p16/op/softmax.mojo
Tips
- Use shared memory for both the maximum value and sum to ensure all threads can access these values
- Remember to call `barrier()` at appropriate points to synchronize threads
- Implement parallel reduction by having each thread process a portion of the input array
- Use a tree-based reduction pattern to minimize thread divergence
- Handle out-of-bounds access carefully, especially for large inputs
- For numerical stability, calculate \(e^{x_i - max}\) instead of \(e^{x_i}\)
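The sketch below models the control flow these tips describe in plain Python (it is not Mojo and not the solution, just a model of the phases): a TPB-sized "shared" buffer, a tree reduction whose stride halves each step, and the points where barrier() would sit in the kernel.

import math

TPB = 128

def block_reduce(shared, op):
    # Tree reduction over a TPB-sized buffer: log2(128) = 7 steps
    # with strides 64, 32, 16, 8, 4, 2, 1.
    stride = TPB // 2
    while stride > 0:
        # In the kernel only threads with local_i < stride are active here,
        # followed by barrier(); this loop plays the role of those threads.
        for local_i in range(stride):
            shared[local_i] = op(shared[local_i], shared[local_i + stride])
        stride //= 2
    return shared[0]

def softmax_block(x):
    n = len(x)
    # Phase 1: each "thread" loads one element (padding with -inf out of bounds).
    shared_max = [x[i] if i < n else float("-inf") for i in range(TPB)]
    block_max = block_reduce(shared_max, max)
    # Phase 2: exponentiate with the max subtracted, then reduce the sum.
    exp_vals = [math.exp(x[i] - block_max) if i < n else 0.0 for i in range(TPB)]
    block_sum = block_reduce(list(exp_vals), lambda a, b: a + b)
    # Phase 3: normalize.
    return [v / block_sum for v in exp_vals[:n]]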
2. Implement the CPU kernel:
fn softmax_cpu_kernel[
layout: Layout,
input_size: Int,
dtype: DType = DType.float32,
](
out: LayoutTensor[dtype, layout, MutableAnyOrigin],
input: LayoutTensor[dtype, layout, MutableAnyOrigin],
):
# FILL IN (roughly 10 lines)
...
View full file: problems/p16/op/softmax.mojo
Tips
- Create a sequential implementation that follows the same mathematical steps as the GPU version
- First find the maximum value across all inputs
- Then compute \(e^{x_i - max}\) for each element and accumulate the sum
- Finally, normalize by dividing each element by the sum
- Use scalar operations since we don’t have parallel threads in the CPU implementation
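As a reference for what the CPU kernel should compute, here are the same three steps in plain Python (a sketch for checking your logic, not the Mojo code itself):

from math import exp

def softmax_cpu_reference(x):
    # Step 1: linear scan for the maximum.
    max_val = max(x)
    # Step 2: exponentiate with the max subtracted and accumulate the sum.
    exps = [exp(v - max_val) for v in x]
    total = sum(exps)
    # Step 3: normalize into a probability distribution.
    return [e / total for e in exps]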
Test the CPU and GPU kernels
uv run poe p16-test-kernels
pixi run p16-test-kernels
When done correctly, you'll see:
Total Discovered Tests: 1
Passed : 1 (100.00%)
Failed : 0 (0.00%)
Skipped: 0 (0.00%)
3. Complete the graph definition:
from pathlib import Path
import numpy as np
from max.driver import CPU, Accelerator, Device, Tensor, accelerator_count
from max.dtype import DType
from max.engine import InferenceSession
from max.graph import DeviceRef, Graph, TensorType, ops
from numpy.typing import NDArray
from scipy.special import softmax as scipy_softmax
def softmax(
input: NDArray[np.float32],
session: InferenceSession,
device: Device,
) -> Tensor:
dtype = DType.float32
input_tensor = Tensor.from_numpy(input).to(device)
mojo_kernels = Path(__file__).parent / "op"
with Graph(
"softmax_graph",
input_types=[
TensorType(
dtype,
shape=input_tensor.shape,
device=DeviceRef.from_device(device),
),
],
custom_extensions=[mojo_kernels],
) as graph:
# FILL IN (roughly 4 unformatted lines)
pass
View full file: problems/p16/p16.py
Tips
- Use `graph.inputs[0]` to access the input tensor passed to the graph
- Call `ops.custom()` with the name matching your registered custom op ("softmax")
- Pass the input tensor as a value to the custom operation
- Specify the output type to match the input shape
- Include the "input_size" parameter which is required by the kernel
- Set `graph.outputs` to a list containing your operation's output tensor
You can run the puzzle with:
uv run poe p16
pixi run p16
When successful, you should see output similar to the following on both CPU and GPU:
Input shape: (128,)
First few random input values: [ 1.1810775 0.60472375 0.5718309 0.6644599 -0.08899796]
Compiling softmax graph on Device(type=cpu,id=0)
Executing softmax on Device(type=cpu,id=0)
====================================================================================================
Compiling softmax graph on Device(type=gpu,id=0)
Executing softmax on Device(type=gpu,id=0)
====================================================================================================
First few softmax results on CPU (custom Mojo kernel): [0.01718348 0.00965615 0.0093437 0.01025055 0.0048253 ]
First few softmax results on GPU (custom Mojo kernel): [0.01718348 0.00965615 0.0093437 0.01025055 0.0048253 ]
First few expected results (SciPy calculation): [0.01718348 0.00965615 0.0093437 0.01025055 0.0048253 ]
Verification passed: Custom kernel results match SciPy calculation
Sum of all probabilities on CPU: 1.0
Sum of all probabilities on GPU: 1.0
This indicates that your custom MAX Graph operation correctly implements the softmax algorithm and produces a valid probability distribution.
Solution
To solve this puzzle, we need to implement both the Mojo kernels (GPU and CPU) and the Python graph definition for our softmax custom operation. Similar to what we did in Puzzle 15, we’re creating a bridge between Python’s ecosystem and Mojo’s GPU-accelerated computing capabilities.
The softmax operation we’re implementing is mathematically defined as:
$$\Large \text{softmax}(x_i) = \frac{e^{x_i}}{\sum_{j=1}^{n} e^{x_j}}$$
However, to prevent numerical overflow, we use the more stable form:
$$\Large \text{softmax}(x_i) = \frac{e^{x_i - \max(x)}}{\sum_{j=1}^{n} e^{x_j - \max(x)}}$$
GPU kernel implementation:
fn softmax_gpu_kernel[
layout: Layout,
input_size: Int,
dtype: DType = DType.float32,
](
output: LayoutTensor[mut=True, dtype, layout],
input: LayoutTensor[mut=False, dtype, layout],
):
shared_max = tb[dtype]().row_major[TPB]().shared().alloc()
shared_sum = tb[dtype]().row_major[TPB]().shared().alloc()
global_i = block_dim.x * block_idx.x + thread_idx.x
local_i = thread_idx.x
var thread_max: Scalar[dtype] = min_finite[dtype]()
if global_i < input_size:
thread_max = rebind[Scalar[dtype]](input[global_i])
shared_max[local_i] = thread_max
barrier()
# Parallel reduction to find max
stride = TPB // 2
while stride > 0:
if local_i < stride:
shared_max[local_i] = max(
shared_max[local_i], shared_max[local_i + stride]
)
barrier()
stride = stride // 2
block_max = shared_max[0]
var exp_val: Scalar[dtype] = 0.0
if global_i < input_size:
exp_val = rebind[Scalar[dtype]](exp(input[global_i] - block_max))
output[global_i] = exp_val
shared_sum[local_i] = exp_val
barrier()
# Parallel reduction for sum
stride = TPB // 2
while stride > 0:
if local_i < stride:
shared_sum[local_i] += shared_sum[local_i + stride]
barrier()
stride = stride // 2
block_sum = shared_sum[0]
# Normalize by sum
if global_i < input_size:
output[global_i] = output[global_i] / block_sum
Kernel signature and memory management
fn softmax_gpu_kernel[
layout: Layout,
input_size: Int,
dtype: DType = DType.float32,
](
output: LayoutTensor[mut=True, dtype, layout],
input: LayoutTensor[mut=False, dtype, layout],
)
The kernel is parameterized with:
- Common layout parameter for both input and output tensors
- Vector size as an Integer parameter
- Configurable data type with float32 as default
- Mutable output tensor for in-place computation
- Non-mutable input tensor (mut=False)
Shared memory allocation
shared_max = tb[dtype]().row_major[TPB]().shared().alloc()
shared_sum = tb[dtype]().row_major[TPB]().shared().alloc()
The kernel allocates two shared memory buffers:
- `shared_max`: for the parallel maximum-finding reduction
- `shared_sum`: for the parallel sum computation
- Both use `TPB` (Threads Per Block = 128) as their size
- Shared memory provides fast access for all threads within a block
Thread indexing
global_i = block_dim.x * block_idx.x + thread_idx.x
local_i = thread_idx.x
Each thread computes:
- `global_i`: its global index in the entire computation space
- `local_i`: its local index within the current thread block

This mapping ensures each thread processes exactly one input element. With a single \(1 \times 1\) block and \(\text{TPB} = \text{SIZE} = 128\), the two indices coincide in this puzzle.
Maximum-finding phase
var thread_max: Scalar[dtype] = min_finite[dtype]()
if global_i < input_size:
thread_max = rebind[Scalar[dtype]](input[global_i])
shared_max[local_i] = thread_max
barrier()
This initializes each thread with:
- The minimum finite value for elements outside the valid range
- The actual input value for threads that map to valid elements
- Storage in shared memory for the reduction process
- A barrier synchronization to ensure all threads complete memory writes
Parallel max reduction
stride = TPB // 2
while stride > 0:
if local_i < stride:
shared_max[local_i] = max(shared_max[local_i], shared_max[local_i + stride])
barrier()
stride = stride // 2
This implements a parallel tree-reduction pattern:
- Start with `stride = 64` (half of `TPB`)
- Each active thread compares two values separated by the stride
- Store the maximum in the lower index
- Synchronize all threads with a barrier
- Halve the stride and repeat
- After \(\log_2(TPB)\) steps, `shared_max[0]` contains the global maximum
This logarithmic reduction is significantly faster than a linear scan on large inputs.
Exponentiation with numerical stability
block_max = shared_max[0]
var exp_val: Scalar[dtype] = 0.0
if global_i < input_size:
exp_val = rebind[Scalar[dtype]](exp(input[global_i] - block_max))
output[global_i] = exp_val
Each thread:
- Reads the global maximum from shared memory
- Subtracts it from its input value before taking the exponential
- This subtraction is crucial for numerical stability - it prevents overflow
- The largest exponent becomes \(e^0 = 1\), and every other value is \(e^{x_i - \max(x)} < 1\)
- Stores the intermediate result in the output buffer
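The shift changes nothing mathematically: softmax is invariant to adding a constant to every input, because the \(e^{-\max(x)}\) factor cancels in the ratio. A quick NumPy check of that invariance (illustrative only, not part of the puzzle code):

import numpy as np
from scipy.special import softmax as scipy_softmax

x = np.random.randn(8).astype(np.float32)
# Shifting the inputs by their maximum leaves the softmax output unchanged.
print(np.allclose(scipy_softmax(x), scipy_softmax(x - x.max())))  # True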
Parallel sum reduction
shared_sum[local_i] = exp_val
barrier()
stride = TPB // 2
while stride > 0:
if local_i < stride:
shared_sum[local_i] += shared_sum[local_i + stride]
barrier()
stride = stride // 2
The second reduction phase:
- Stores all exponential values in shared memory
- Uses the same tree-based reduction pattern as for max
- But performs addition instead of maximum comparison
- After \(\log_2(TPB)\) steps, `shared_sum[0]` contains the total sum of all exponentials
Final normalization
block_sum = shared_sum[0]
if global_i < input_size:
output[global_i] = output[global_i] / block_sum
Each thread:
- Reads the total sum from shared memory
- Divides its exponential value by this sum
- Writes the normalized probability to the output buffer
- This produces a valid probability distribution that sums to 1
Performance characteristics
The implementation has excellent performance characteristics:
- Complexity: \(O(\log n)\) for both max and sum calculations vs \(O(n)\) in a sequential approach
- Memory efficiency: Uses only \(2 \times TPB\) elements of shared memory
- Work efficiency: Each thread performs approximately \(2 \times \log_2(n)\) operations
- Load balancing: Each thread handles the same amount of work
- Synchronization: Uses minimal barriers, only where necessary
- Memory access: Coalesced global memory access pattern for optimal bandwidth
The algorithm is also numerically robust, handling potential overflow/underflow cases by applying the max-subtraction technique that maintains precision across the wide range of values common in neural network activations.
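As a back-of-the-envelope check on the complexity claim (a step count, not a benchmark):

import math

n = 128                                  # SIZE == TPB in this puzzle
sequential_steps = 2 * (n - 1)           # one linear pass for max, one for sum
parallel_steps = 2 * int(math.log2(n))   # 7 tree-reduction steps each for max and sum
print(sequential_steps, parallel_steps)  # 254 vs 14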
CPU fallback implementation:
fn softmax_cpu_kernel[
layout: Layout,
input_size: Int,
dtype: DType = DType.float32,
](
output: LayoutTensor[dtype, layout, MutableAnyOrigin],
input: LayoutTensor[dtype, layout, MutableAnyOrigin],
):
var max_val: Scalar[dtype] = min_finite[dtype]()
for i in range(input_size):
max_val = max(max_val, rebind[Scalar[dtype]](input[i]))
var sum_exp: Scalar[dtype] = 0.0
for i in range(input_size):
var exp_val = rebind[Scalar[dtype]](exp(input[i] - max_val))
output[i] = exp_val
sum_exp += exp_val
for i in range(input_size):
output[i] = output[i] / sum_exp
- Maximum Finding:
var max_val: Scalar[dtype] = min_finite[dtype]()
for i in range(input_size):
    max_val = max(max_val, rebind[Scalar[dtype]](input[i]))
We initialize with the minimum finite value and perform a linear scan through the array, keeping track of the maximum value encountered. This has \(O(n)\) complexity but works efficiently on CPU where we don’t have many cores to parallelize across.
- Exponential Computation and Summation:
var sum_exp: Scalar[dtype] = 0.0
for i in range(input_size):
    var exp_val = rebind[Scalar[dtype]](exp(input[i] - max_val))
    output[i] = exp_val
    sum_exp += exp_val
We compute \(e^{x_i - max}\) for each element, store the result in the output buffer, and accumulate the sum \(\sum_{j=1}^{n} e^{x_j - max}\) in a single pass. This approach minimizes memory operations compared to using separate loops.
- Normalization:
for i in range(input_size):
    output[i] = output[i] / sum_exp
Finally, we normalize each element by dividing by the sum, producing a proper probability distribution according to the softmax formula:
$$\Large \text{softmax}(x_i) = \frac{e^{x_i - \max(x)}}{\sum_{j=1}^{n} e^{x_j - \max(x)}}$$
The CPU implementation uses the same numerical stability technique (subtracting the maximum) but with sequential operations rather than parallel ones. It’s simpler than the GPU version since it doesn’t need to handle shared memory or thread synchronization, but it’s also less efficient for large inputs.
Both implementations are registered with MAX Graph's custom operation system through the `@compiler.register("softmax")` decorator, allowing seamless execution on either device type based on availability.
Python integration:
with Graph(
"softmax_graph",
input_types=[
TensorType(
dtype,
shape=input_tensor.shape,
device=DeviceRef.from_device(device),
),
],
custom_extensions=[mojo_kernels],
) as graph:
input_value = graph.inputs[0]
# The output shape is the same as the input for softmax
# Note: the name must match the name used in `@compiler.register("softmax")` in op/softmax.mojo
output = ops.custom(
name="softmax",
values=[input_value],
out_types=[
TensorType(
dtype=input_value.tensor.dtype,
shape=input_value.tensor.shape,
device=DeviceRef.from_device(device),
)
],
parameters={
"input_size": input_tensor.shape[0],
"dtype": dtype,
},
)[0].tensor
graph.output(output)
- Graph Setup and Configuration:
with Graph(
    "softmax_graph",
    input_types=[
        TensorType(
            dtype,
            shape=input_tensor.shape,
            device=DeviceRef.from_device(device),
        ),
    ],
    custom_extensions=[mojo_kernels],
) as graph:
This creates a computation graph named “softmax_graph” that:
- Defines the input tensor type with proper dtype and shape
- Maps the tensor to the target device (CPU or GPU)
- Loads our custom Mojo operations from the specified directory
- The `custom_extensions` parameter is crucial for linking to our Mojo implementation
- Custom Operation Configuration:
output = ops.custom(
    name="softmax",
    values=[input_value],
    out_types=[
        TensorType(
            dtype=input_value.tensor.dtype,
            shape=input_value.tensor.shape,
            device=DeviceRef.from_device(device),
        )
    ],
    parameters={
        "input_size": input_tensor.shape[0],
        "dtype": dtype,
    },
)[0].tensor
This sets up our custom operation with:
- Name matching the `@compiler.register("softmax")` in our Mojo code
- Input values passed as a list
- Output type definition matching the input shape and type
- Parameters required by our kernel, including the vector size and data type
- We extract the tensor from the first returned element with `[0].tensor`
- Graph Output Definition:
graph.output(output)
This registers our operation’s result as the graph’s output.
The main script includes comprehensive testing that:
- Generates random input data: `np.random.randn(INPUT_SIZE).astype(np.float32)`
- Calculates expected results with SciPy: `scipy_softmax(input_array)`
- Verifies numerical accuracy: `np.testing.assert_allclose(..., rtol=1e-5)`
- Confirms the output is a valid probability distribution: `np.sum(result.to_numpy())`
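Put together, the verification flow looks roughly like this (a sketch of the checks above; it assumes session and device have already been created as in p16.py's main script, and softmax is the graph-backed function defined earlier in this puzzle):

import numpy as np
from scipy.special import softmax as scipy_softmax

INPUT_SIZE = 128
input_array = np.random.randn(INPUT_SIZE).astype(np.float32)

expected = scipy_softmax(input_array)
result = softmax(input_array, session, device)  # graph-backed softmax from p16.py

# Results must match SciPy and sum to 1 like a proper probability distribution.
np.testing.assert_allclose(result.to_numpy(), expected, rtol=1e-5)
assert np.isclose(result.to_numpy().sum(), 1.0)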
This implementation showcases the power of MAX Graph for integrating high-performance Mojo kernels with Python’s scientific computing ecosystem, providing both efficiency and ease of use.