Custom Operations: Applications in AI Models
In this recipe, we will cover how to write custom MAX Graph operations in Mojo and use them inside AI models. We'll walk through two examples that put this into practice: a top-K token sampler with CPU and GPU implementations, and a fused FlashAttention-2 attention kernel that runs on the GPU.
Let's get started.
Please make sure your system meets our system requirements.
To proceed, ensure you have the magic CLI installed, with magic --version reporting 0.7.2 or newer:
curl -ssL https://magic.modular.com/ | bash
Or update it via:
magic self-update
These examples can all be run on either a CPU or a GPU. To run them on a GPU, ensure your system meets the GPU requirements.
Get the recipe with the magic CLI:
magic init custom-ops-ai-applications --from custom-ops-ai-applications
cd custom-ops-ai-applications
Then run the examples:
magic run top_k
magic run fused_attention
magic run benchmarks
AI models in MAX are built as computational graphs using the MAX Graph API. MAX contains within it a powerful graph compiler that can take these graphs and optimize them for best performance on a wide range of hardware.
Each node in a MAX Graph is defined by an operation that performs a calculation on zero or more inputs and produces one or more outputs. These inputs and outputs tend to be in the form of tensors, and the operations are usually data-parallel calculations that are accelerated on CPUs or GPUs. In MAX, these operations are written using Mojo, a Python-family language built for high-performance computation.
Large language models rely on token samplers to improve the quality of the text generated from the model, as well as add interesting variability to the output. One sampling technique is top-K token sampling, and this example provides both CPU and GPU implementations of this algorithm. The GPU implementation demonstrates how to accelerate the sampling via hardware features.
The following can be used to run the top-K token sampling demo:
magic run top_k
The file top_k.py defines a block of text, chooses three words from it, and builds a NumPy array with three batches recording how often each candidate "next word" appears, expressed as percentages. That array is passed to the custom op, which returns two arrays ordering the words in each batch by highest frequency. The op uses a top_k kernel that runs on the CPU, or on a MAX-compatible GPU if you have one attached. The GPU kernel uses a warp-level algorithm to demonstrate low-level GPU primitives, and each batch runs in parallel on a separate GPU block.
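Conceptually, top-K sampling keeps only the K highest-frequency candidates in each batch and samples among them. The following NumPy sketch illustrates that idea with made-up frequencies and a made-up K; it is not the recipe's kernel, which performs this selection on the device:

```python
import numpy as np

# Hypothetical stand-in for the recipe's input: for each of three batches, how
# often each candidate "next word" appears, as percentages (values are made up).
frequencies = np.array([
    [10.0, 55.0,  5.0, 30.0],
    [40.0, 20.0, 25.0, 15.0],
    [ 5.0,  5.0, 70.0, 20.0],
])

K = 3  # keep the K most frequent candidates per batch

# Indices and values of each batch's K most frequent words, highest first --
# analogous to the two ordered arrays the custom op returns.
top_idx = np.argsort(-frequencies, axis=1)[:, :K]
top_val = np.take_along_axis(frequencies, top_idx, axis=1)

# To sample a "next word", renormalize the kept frequencies and draw from them.
probs = top_val / top_val.sum(axis=1, keepdims=True)
rng = np.random.default_rng(seed=0)
sampled = np.array([rng.choice(top_idx[b], p=probs[b]) for b in range(len(frequencies))])
print(sampled)  # one chosen word index per batch
```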
You can look at the kernels/top_k.mojo file to see the differences between the CPU and GPU implementations, and run magic run benchmarks to see the performance difference.
This demonstrates how you can build your own custom op, using low-level GPU and CPU primitives, for any functionality you want to add alongside MAX's performant op implementations. Note that this is a simplified version: MAX has its own mo.top_k op, which is more feature-complete.
Modern Transformer-based language models are constructed around the attention mechanism. Optimizing how attention is performed is a key driver in improving large language model performance.
FlashAttention-2 is a memory-efficient attention algorithm that significantly improves the performance of transformer-based models by reducing memory bandwidth requirements and optimizing computation patterns. It is particularly beneficial for models processing long sequences, where the memory traffic of classic attention grows quadratically with sequence length.
In this example, you'll see how to implement FlashAttention-2 as a fused operation that runs on the GPU in MAX using Mojo.
To run the example, use the following command:
magic run fused_attention
The classic attention operation follows this general structure. It consists of:
- bmm: Q x Transpose(K), where Q and K both have shape [batchSize, numHeads, S, d] and Q x K^t has shape [batchSize, numHeads, S, S]
- softmax applied to Q x K^t
- bmm: softmax(Q x K^t) x V, where V has shape [batchSize, numHeads, S, d]
Here, bmm is short for batched matrix multiplication.
S denotes the sequence length. Depending on the model, it can be as large as O(10^3)-O(10^4). d is the size per head in multi-head attention; it's usually a power of 2 such as 64 or 128, and smaller than S.
A limitation of the classic implementation is that it materializes an intermediate matrix of shape [batchSize, numHeads, S, S]. This introduces O(S^2) memory allocation and traffic.
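To make that cost concrete, here is a minimal NumPy sketch of the classic (unfused) computation for a single batch and head. The shapes are illustrative, not taken from the example, and the 1/sqrt(d) scaling is omitted to match the description above:

```python
import numpy as np

def classic_attention(Q, K, V):
    """Reference attention for one [batch, head]: materializes the full
    [S, S] score matrix, which is exactly what FlashAttention avoids."""
    scores = Q @ K.T                              # [S, S] intermediate: O(S^2) memory
    scores -= scores.max(axis=-1, keepdims=True)  # subtract the row max for numerical stability
    P = np.exp(scores)
    P /= P.sum(axis=-1, keepdims=True)            # row-wise softmax
    return P @ V                                  # [S, d] output

S, d = 1024, 64  # illustrative sizes
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((S, d), dtype=np.float32) for _ in range(3))
out = classic_attention(Q, K, V)  # the S x S scores alone occupy 1024 * 1024 * 4 bytes = 4 MiB
```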
FlashAttention optimizes the standard attention mechanism by tiling the Q, K, and V matrices into smaller blocks that fit in GPU shared memory, which is much faster than global memory, and by computing the softmax incrementally (tracking running row-wise maximums and sums) so the full [S, S] score matrix never has to be materialized. Fusing the whole computation into a single kernel in this way maximizes locality and reduces DRAM (global memory) traffic.
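Before looking at the Mojo kernel, here is a NumPy sketch of the same blockwise recurrence for a single head. The block size and shapes are arbitrary, and this is only an illustrative reference, not the recipe's implementation; the running maximum m, running sum l, and accumulator O play the same roles as m_1, l_1, and O_i in the kernel below:

```python
import numpy as np

def blockwise_attention(Q, K, V, block=8):
    """Attention computed one K/V tile at a time using running-max/running-sum
    ("online softmax") updates, so no [S, S] score matrix is ever built."""
    S_len = Q.shape[0]
    m = np.full((S_len, 1), -np.inf)  # running row-wise max of the scores
    l = np.zeros((S_len, 1))          # running row-wise sum of exp(score - m)
    O = np.zeros_like(Q)              # running output accumulator
    for start in range(0, S_len, block):
        K_tile, V_tile = K[start:start + block], V[start:start + block]
        scores = Q @ K_tile.T                                  # only [S, block] at a time
        m_new = np.maximum(m, scores.max(axis=1, keepdims=True))
        l_new = np.exp(m - m_new) * l + np.exp(scores - m_new).sum(axis=1, keepdims=True)
        P = np.exp(scores - m_new) / l_new
        O = O * (l / l_new) * np.exp(m - m_new) + P @ V_tile  # rescale old result, add new tile
        m, l = m_new, l_new
    return O

# Agrees with the classic implementation up to floating-point error.
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((32, 16)) for _ in range(3))
scores = Q @ K.T
P = np.exp(scores - scores.max(axis=1, keepdims=True))
reference = (P / P.sum(axis=1, keepdims=True)) @ V
assert np.allclose(blockwise_attention(Q, K, V), reference)
```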
This is the core of the fused FlashAttention kernel used in this example:
```mojo
alias N = Q.shape[0]()
alias D = Q.shape[1]()

# Each GPU block computes one BN x BD tile of the output; Q_tile holds the
# rows of Q this block is responsible for.
Q_tile = Q.tile[BN, D](block_idx.y, 0)

# Running row-wise maximum of the attention scores seen so far,
# initialized to the minimum representable value.
m_1 = (
    LayoutTensor[q_dtype, Layout(BN, 1), MutableAnyOrigin]
    .stack_allocation()
    .fill(Scalar[q_dtype].MIN)
)
# Running row-wise sum of exp(score - m_1).
l_1 = (
    LayoutTensor[q_dtype, Layout(BN, 1), MutableAnyOrigin]
    .stack_allocation()
    .fill(0)
)
# Accumulator for this block's output tile.
O_i = (
    LayoutTensor[q_dtype, Layout.row_major(BN, BD), MutableAnyOrigin]
    .stack_allocation()
    .fill(0)
)

# Number of K/V rows processed per iteration.
alias BN_1 = 8

# Walk over the K and V tiles, updating the online-softmax statistics.
@parameter
for tile_n_idx in range(N // BN_1):
    K_tile = K.tile[BN_1, D](tile_n_idx, 0)
    V_tile = V.tile[BN_1, BD](tile_n_idx, block_idx.x)

    # Scores for this tile: Q_tile x K_tile^t.
    S = matmul["gpu", transpose_b=True](Q_tile, K_tile)
    # Update the running max and the rescaled running sum.
    m_2 = max(m_1, rebind[__type_of(m_1)](max[axis=1](S)))
    l_2 = exp(m_1 - m_2) * l_1 + sum[axis=1](exp(S - m_2))
    # Normalized probabilities for this tile.
    P = exp(S - m_2) / l_2
    # Rescale the previous accumulator and add this tile's contribution.
    O_i = O_i * (l_1 / l_2) * exp(m_1 - m_2) + matmul["gpu"](P, V_tile)
    m_1 = m_2
    l_1 = rebind[__type_of(l_1)](l_2)

# Write the finished tile back to the output tensor in global memory.
O.tile[BN, BD](block_idx.y, block_idx.x).copy_from(O_i)
```
Note how the Mojo abstractions present in MAX allow for this algorithm to be expressed very closely to the description in the original research paper.
In this recipe, we've demonstrated how to create custom MAX Graph operations that perform functions important in modern AI models: top-K token sampling and the FlashAttention-2 attention layer. Each provides examples of how complex calculations can be constructed using MAX and Mojo and targeted towards hardware features in GPUs.
Follow our tutorial for building a custom operation from scratch.
Explore MAX's documentation for additional features. The gpu module has details on Mojo's GPU programming functions and types, and the documentation for @compiler.register shows how to register custom graph operations.
Join our Modular Forum and Discord community to share your experiences and get support.
We're excited to see what you'll build with MAX! Share your projects and experiences with us using #ModularAI on social media.
DETAILS
The code: custom-ops-ai-applications
Author: Brad Larson
Available tasks: magic run top_k, magic run fused_attention, magic run benchmarks