Help us improve
Share bugs, ideas, or general feedback.
From mlx-dev
Writes idiomatic MLX code for Apple Silicon ML, handling arrays, neural networks, training loops, lazy evaluation, unified memory, Metal GPU, and API differences from PyTorch/NumPy.
npx claudepluginhub joshuarweaver/cascade-code-languages-python --plugin ettrickshepherd-mlx-dev-skill
How this skill is triggered — by the user, by Claude, or both
Slash command
/mlx-dev:mlx-dev
The summary Claude sees in its skill listing — used to decide when to auto-load this skill
Use `uv` for Python environment and package management:
Runs LLMs on Apple Silicon with MLX/mlx_lm: unified memory, 4-bit quantization, streaming generation, prompt caching. For M-series chips.
Guide for selecting and deploying on-device AI on Apple platforms: Foundation Models, Core ML, MLX Swift, and llama.cpp. Covers model conversion, quantization, structured output, and Neural Engine optimization.
Provides empirical rules for authoring PyTorch models targeting on-device execution on Apple platforms (Neural Engine, GPU). Covers op compatibility, BC1S layout, KV cache patterns, correctness testing via PSNR, and common debugging issues.
Share bugs, ideas, or general feedback.
Use uv for Python environment and package management:
# Install MLX
uv add mlx
# Run MLX scripts
uv run python train.py
# Run with specific dependencies
uv run --with mlx python script.py
Operations build a graph; nothing computes until mx.eval():
# CORRECT: Evaluate at iteration boundaries
for batch in dataset:
    # These two calls only extend the lazy graph; nothing is computed yet.
    loss, grads = value_and_grad_fn(model, batch)
    optimizer.update(model, grads)
    # Evaluate once, at the end of the iteration.
    mx.eval(loss, model.parameters())  # ALL computation here
# WRONG: Evaluating too frequently
# Anti-pattern, shown on purpose: forcing evaluation after every small op.
for _ in range(100):
    a = a + b
    mx.eval(a)  # Massive overhead!
Implicit eval triggers: print(a), a.item(), np.array(a), if a > 0:.
# Lists must be mx.array
a[[0, 1]] # ValueError!
a[mx.array([0, 1])] # Works
# Slice indices must be Python ints
i = mx.array(2)
x[i:i+2] # ValueError!
x[i.item():i.item()+2] # Works (forces eval)
# Slices create COPIES, not views (opposite of NumPy)
b = a[:]
b[2] = 0 # a is unchanged!
# Boolean mask READS not supported
a[mask] # Not supported - use mx.where()
# No bounds checking - out-of-bounds returns garbage
For accumulating updates, use at[] syntax:
a = a.at[idx].add(1) # Properly accumulates at duplicate indices
See references/array-indexing.md for complete patterns.
# Conv2d uses NHWC (not NCHW like PyTorch)
x_mlx = mx.array(x_torch.numpy().transpose(0, 2, 3, 1))
# Override __call__, not forward()
class MyModel(nn.Module):
    """Minimal module illustrating that the forward pass lives in ``__call__``.

    ``self.layer`` is not defined in this snippet; it is assumed to be
    assigned in an ``__init__`` that the example omits.
    """

    def __call__(self, x):  # NOT forward()
        """Apply ``self.layer`` to ``x`` and return the result."""
        return self.layer(x)
# No dtype in constructors - use set_dtype()
layer = nn.Linear(10, 10)
layer.set_dtype(mx.bfloat16)
See references/neural-networks.md for layer equivalents.
a = mx.array([1.0], dtype=mx.float64)
mx.exp(a, stream=mx.gpu) # RuntimeError!
# Solutions:
mx.exp(a, stream=mx.cpu)
mx.exp(a.astype(mx.float32))
# bfloat16 from external sources gets misinterpreted
from ml_dtypes import bfloat16
x = np.array(1., dtype=bfloat16)
mx.array(x) # Returns complex64!
mx.array(x.astype(np.float32), dtype=mx.bfloat16) # Correct
See references/dtypes.md for full type support table.
from functools import partial
state = [model.state, optimizer.state, mx.random.state] # Include random!
@partial(mx.compile, inputs=state, outputs=state)
def train_step(x, y):
    """Run one compiled optimization step and return the loss.

    ``state`` is passed as both ``inputs`` and ``outputs`` of ``mx.compile``
    because the step reads and updates the model, optimizer and random state
    that it holds.
    """
    loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
    optimizer.update(model, grads)
    return loss
# No print() in compiled functions - crashes during tracing
# String decoding triggers recompilation - decode outside loop
See references/compilation.md for recompilation triggers.
| Type | GPU | Notes |
|---|---|---|
| float32 | Yes | Default float |
| float16 | Yes | |
| bfloat16 | Yes | M3+ recommended |
| float64 | CPU only | GPU throws! |
| int8-64, uint8-64 | Yes | |
| complex64 | Partial | No matmul |
| PyTorch | MLX |
|---|---|
| `tensor.to('cuda')` | Not needed (unified memory) |
| `nn.forward()` | `nn.__call__()` |
| NCHW format | NHWC format |
| `torch.gather()` | `mx.take_along_axis()` |
| `torch.scatter_add_()` | `arr.at[idx].add()` |
- `np.nonzero()` - restructure algorithm
- `np.unique()` - pre-sort or use dicts
- `arr[bool_mask]` read - use `mx.where()`
- `np.linalg.det()`, `np.linalg.lstsq()`

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from functools import partial
class Model(nn.Module):
    """Two stacked linear layers (784 -> 256 -> 10)."""

    def __init__(self):
        """Create the two linear layers and store them in ``self.layers``."""
        super().__init__()
        self.layers = [nn.Linear(784, 256), nn.Linear(256, 10)]

    def __call__(self, x):
        """Apply every layer in order and return the last layer's output.

        The output of each layer except the last is clamped at zero with
        ``mx.maximum(..., 0)``; the last layer's output is returned as is.
        """
        for layer in self.layers[:-1]:
            x = mx.maximum(layer(x), 0)
        return self.layers[-1](x)
def loss_fn(model, x, y):
    """Return the mean cross-entropy between ``model(x)`` and targets ``y``."""
    return nn.losses.cross_entropy(model(x), y, reduction="mean")
# Instantiate the network and its optimizer.
model = Model()
optimizer = optim.AdamW(learning_rate=1e-3)
# State handed to mx.compile as both inputs and outputs of the training step:
# model state, optimizer state, and the global random state.
state = [model.state, optimizer.state, mx.random.state]
@partial(mx.compile, inputs=state, outputs=state)
def train_step(x, y):
    """Run one compiled optimization step on a batch and return its loss.

    Computes the loss and gradients of ``loss_fn`` for ``model`` on
    ``(x, y)``, then lets ``optimizer`` update the model with them.
    ``state`` is declared as both compile inputs and outputs because the
    step reads and updates it.
    """
    loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
    optimizer.update(model, grads)
    return loss
# `num_epochs` and `dataloader` are not defined in this example; they are
# assumed to be supplied by the surrounding script.
for epoch in range(num_epochs):
    for x_batch, y_batch in dataloader:
        loss = train_step(x_batch, y_batch)
        # Evaluate the updated state once per batch.
        mx.eval(state)
    # Report the last batch's loss for the epoch.
    print(f"Epoch {epoch}: {loss.item():.4f}")