# Nabla: Differentiable Programming in Mojo

Nabla is a high-performance scientific computing and machine learning framework that combines imperative and functional APIs. Built on Mojo & MAX, it lets you seamlessly drop custom Mojo kernels into the autodiff engine and automatically shard distributed workloads.

> **Active Development:** This is the main development branch, featuring distributed SPMD execution and a refined lazy, MAX-native execution model. Read the docs: https://nablaml.com
## Installation

Nabla requires the Modular nightly build:

```bash
python -m venv venv
source venv/bin/activate
pip install --pre --extra-index-url https://whl.modular.com/nightly/simple/ modular nabla-ml
```

**GPU Support:** on macOS, install the Xcode Command Line Tools first (`xcode-select --install`).

### Development Setup

Installs all development dependencies (torch/jax for testing, mypy/black for linting, etc.):

```bash
git clone https://github.com/nabla-ml/nabla.git
cd nabla
python -m venv venv
source venv/bin/activate
pip install -r requirements-dev.txt
pip install -e ".[dev]"
```

## Automatic Differentiation

Define Python functions and compute gradients using trace-based automatic differentiation. Read more.
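The trace-and-replay idea can be sketched in a few lines of plain Python. This is a toy, micrograd-style scalar sketch (not Nabla's actual internals): the forward pass records each primitive together with its local gradient rule, and `grad` replays the recorded trace backward.

```python
class Var:
    # Scalar node in the computation trace
    def __init__(self, value):
        self.value = value
        self.grad = 0.0
        self._backward = lambda: None
        self._parents = ()

    def __mul__(self, other):
        out = Var(self.value * other.value)
        out._parents = (self, other)
        def _backward():  # local gradient rule for multiplication
            self.grad += other.value * out.grad
            other.grad += self.value * out.grad
        out._backward = _backward
        return out

    def __add__(self, other):
        out = Var(self.value + other.value)
        out._parents = (self, other)
        def _backward():  # local gradient rule for addition
            self.grad += out.grad
            other.grad += out.grad
        out._backward = _backward
        return out

def grad(f, *inputs):
    # Trace forward, then replay the recorded ops in reverse order
    xs = [Var(v) for v in inputs]
    out = f(*xs)
    order, seen = [], set()
    def visit(v):  # topological order via DFS
        if id(v) not in seen:
            seen.add(id(v))
            for p in v._parents:
                visit(p)
            order.append(v)
    visit(out)
    out.grad = 1.0
    for v in reversed(order):
        v._backward()
    return [x.grad for x in xs]

# d/dx (x*y + x) = y + 1, d/dy (x*y + x) = x
print(grad(lambda x, y: x * y + x, 3.0, 4.0))  # -> [5.0, 3.0]
```

Nabla traces whole tensor programs the same way, which is what makes "backward replay" in the example below possible.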
```python
import nabla

# Use Accelerator (GPU) or CPU for execution
with nabla.default_device(nabla.Accelerator()):
    x = nabla.uniform((4, 8))
    w = nabla.uniform((8, 16))

    # Define the loss function
    def compute_loss(x, w):
        return nabla.mean(nabla.relu(x @ w))

    # Compute the loss (implicit .realize() on print)
    loss = compute_loss(x, w)
    print("Loss:", loss)

    # Compute gradients via backward replay
    grad_x, grad_w = nabla.grad(compute_loss, argnums=(0, 1))(x, w)
    print("Gradients:", grad_x.shape, grad_w.shape)
```

## Sharded Execution

Shard tensors on a logical device mesh; operations automatically propagate sharding constraints. Read more.
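The mesh idea can be simulated in plain Python. In this hypothetical sketch (lists standing in for devices, not Nabla's implementation), `x` is split by rows ('dp') and `w` by columns ('tp'); each "device" computes a local block matmul, and the global mean then requires a cross-device sum, which is exactly the AllReduce that Nabla inserts automatically.

```python
import random

def matmul(a, b):
    # Naive dense matmul on nested lists
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

random.seed(0)
x = [[random.random() for _ in range(4)] for _ in range(4)]
w = [[random.random() for _ in range(4)] for _ in range(4)]

# Shard: 2 row-blocks of x ('dp') x 2 column-blocks of w ('tp')
row_blocks = [x[:2], x[2:]]
col_blocks = [[r[:2] for r in w], [r[2:] for r in w]]

# Each "device" computes its local block of x @ w and a local partial sum
local_sums = []
for xb in row_blocks:
    for wb in col_blocks:
        block = matmul(xb, wb)
        local_sums.append(sum(map(sum, block)))

# "AllReduce": combine per-device partial sums into the global mean
sharded_mean = sum(local_sums) / 16
full_mean = sum(map(sum, matmul(x, w))) / 16
assert abs(sharded_mean - full_mean) < 1e-12
```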
```python
# Define a 2×4 device mesh (logical DP × TP)
mesh = nabla.DeviceMesh("my_mini_pod", (2, 4), ("dp", "tp"))

# Shard x on 'dp' (rows), w on 'tp' (columns)
x = nabla.shard(nabla.uniform((32, 128)), mesh, nabla.P("dp", None))
w = nabla.shard(nabla.uniform((128, 256)), mesh, nabla.P(None, "tp"))

def compute_loss(x, w):
    return nabla.mean(nabla.relu(x @ w))

# An AllReduce is inserted automatically for the 'tp' sum
loss = compute_loss(x, w)
print("Loss (Sharded):", loss)
```

## Custom Mojo Kernels

Nabla's core strength is its ability to drop down to Mojo for high-performance custom kernels, bridging the gap between high-level Python and bare-metal execution. Read more.
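In spirit, the kernel registered below is just an elementwise map over every index of the input. A pure-Python stand-in for what it computes (hypothetical, eliding the SIMD width and device plumbing):

```python
def foreach(fn, output, x):
    # Apply fn at every index; the Mojo version vectorizes and targets GPU/CPU
    for i in range(len(x)):
        output[i] = fn(x[i])

def add_one(v):
    return v + 1

x = [1.0, 2.0, 3.0]
out = [0.0] * 3
foreach(add_one, out, x)
print(out)  # -> [2.0, 3.0, 4.0]
```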
**Mojo Kernel** (`kernels/custom_kernel.mojo`)

```mojo
@compiler.register("my_kernel")
struct MyKernel:
    @staticmethod
    def execute[target: StaticString](
        output: OutputTensor,
        x: InputTensor[dtype = output.dtype, rank = output.rank],
        ctx: DeviceContextPtr,
    ):
        @parameter
        fn add_one[W: Int](idx: IndexList[x.rank]) -> SIMD[x.dtype, W]:
            return x.load[W](idx) + 1

        foreach[add_one, target=target](output, ctx)
```

**Python Usage**
```python
class AddOneOp(nabla.UnaryOperation):
    name = "my_kernel"

    def kernel(self, x, **kwargs):
        # Concise invocation: (func_name, path, inputs, out_types)
        return nabla.call_custom_kernel("my_kernel", "./kernels", x, x.type)

x = nabla.Tensor.constant([1., 2., 3.])
y = AddOneOp()(x)
```

## Pipeline Parallelism

Define complex distributed schedules such as GPipe, using `vmap` for parallel execution and `ppermute` for explicit data movement. Read more.
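The schedule's core move is a ring rotation of per-stage activations. A hypothetical pure-Python sketch, with a list standing in for the stacked per-stage tensors and `ppermute` reduced to an index permutation:

```python
stages = 4

def ppermute(xs, perm):
    # Move xs[src] to position dst for each (src, dst) pair
    out = [None] * len(xs)
    for src, dst in perm:
        out[dst] = xs[src]
    return out

def pipeline_step(state, fresh_input, mask_0):
    # 1. "Compute": identity here; real stages would apply their layer
    computed = list(state)
    # 2. Communicate: shift activations to the next stage (i -> i+1)
    shifted = ppermute(computed, [(i, (i + 1) % stages) for i in range(stages)])
    # 3. Control: stage 0 takes fresh input; the others take shifted data
    return [fresh_input if m else s for m, s in zip(mask_0, shifted)]

state = ["a", "b", "c", "d"]
mask_0 = [True, False, False, False]
print(pipeline_step(state, "e", mask_0))  # -> ['e', 'a', 'b', 'c']
```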
```python
# Parallel execution across 'num_stages'
@nabla.vmap(in_axes=(0, 0), spmd_axis_name="stage")
def stage_compute(x, w):
    return nabla.relu(x @ w)

def pipeline_step(current_state, fresh_input, weights, mask_0):
    # 1. Compute: run all stages in parallel
    computed = stage_compute(current_state, weights)
    # 2. Communicate: shift activations to the next stage (i -> i+1)
    shifted = nabla.ppermute(computed, perm=[(i, (i + 1) % stages) for i in range(stages)])
    # 3. Control: stage 0 takes fresh input; the others take shifted data
    return nabla.where(mask_0, fresh_input, shifted)
```

## Symbolic Shapes

Compile functions once with symbolic dimensions to handle varying input sizes without recompilation.
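The caching trick behind this is that the compiled-graph cache key replaces each dynamic dimension with a symbol, so any concrete size maps to the same entry. A hypothetical simplified sketch (shapes passed explicitly; `dynamic_dims` reduced to a set of dimension indices):

```python
compiled_cache = {}

def compile_fn(fn, dynamic_dims):
    def key_for(shape):
        # Symbolize dynamic dims so all batch sizes share one key
        return tuple("batch" if i in dynamic_dims else d
                     for i, d in enumerate(shape))
    def wrapper(x):  # x is a (shape, data) pair in this sketch
        shape, data = x
        key = (fn, key_for(shape))
        if key not in compiled_cache:
            compiled_cache[key] = fn  # stand-in for real graph compilation
        return compiled_cache[key](data)
    return wrapper

square = compile_fn(lambda d: [v * v for v in d], dynamic_dims={0})
square(((2, 10), [1, 2]))    # triggers "compilation"
square(((128, 10), [3, 4]))  # same symbolic key: no recompilation
assert len(compiled_cache) == 1
```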
```python
# Compile once for ANY batch size (dim 0)
@nabla.compile(dynamic_dims={0: {0: "batch"}})
def square(x):
    return x * x

x_small = nabla.uniform((2, 10))
x_large = nabla.uniform((128, 10))

res1 = square(x_small)  # Triggers compilation
res2 = square(x_large)  # Reuses the compiled graph!
```

## Design Principles

Nabla relies on three core principles:
- **Lazy execution:** operations only build a graph; nothing runs until `.realize()` is called (printing a tensor realizes it implicitly).
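The lazy model can be illustrated with a hypothetical minimal sketch (not Nabla's implementation): ops return a recipe rather than a value, and the computation only fires on `.realize()` or when the result is printed.

```python
class LazyTensor:
    def __init__(self, thunk):
        self._thunk = thunk   # a recipe, not a value
        self._value = None

    def __add__(self, other):
        # Build a bigger recipe; still no computation
        return LazyTensor(lambda: self.realize() + other.realize())

    def realize(self):
        if self._value is None:
            self._value = self._thunk()  # run the recorded graph once
        return self._value

    def __repr__(self):  # printing realizes implicitly
        return f"LazyTensor({self.realize()})"

a = LazyTensor(lambda: 2)
b = LazyTensor(lambda: 3)
c = a + b              # nothing computed yet
assert c._value is None
print(c)               # -> LazyTensor(5)
```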
## License

Apache-2.0 — see LICENSE