Write correct, idiomatic Apple MLX code for Apple Silicon ML. Use when working with MLX arrays, neural networks, training loops, lazy evaluation, unified memory, mx.eval, mx.compile, Metal GPU, memory optimization, quantization, or Apple Silicon performance. Covers critical API differences from PyTorch/NumPy, array indexing gotchas (lists must be mx.array, slices create copies), NHWC format for Conv2d, __call__ not forward(), float64 CPU-only, mlx-lm integration, and debugging patterns.
Generates Apple MLX code for neural networks, training loops, and Metal GPU optimization on Apple Silicon.
- references/array-indexing.md
- references/compilation.md
- references/dtypes.md
- references/error-decoder.md
- references/gradients.md
- references/memory-management.md
- references/neural-networks.md
- references/pytorch-migration.md
- references/random.md
- scripts/check_memory.py

Use uv for Python environment and package management:
# Install MLX
uv add mlx
# Run MLX scripts
uv run python train.py
# Run with specific dependencies
uv run --with mlx python script.py
MLX is lazy: operations only build a compute graph, and nothing is computed until the graph is evaluated, explicitly with mx.eval() or implicitly:
# CORRECT: Evaluate at iteration boundaries
# value_and_grad_fn = nn.value_and_grad(model, loss_fn), with loss_fn(model, batch)
for batch in dataset:
    loss, grads = value_and_grad_fn(model, batch)
    optimizer.update(model, grads)
    mx.eval(loss, model.parameters(), optimizer.state)  # ALL computation happens here
# WRONG: Evaluating too frequently
for _ in range(100):
    a = a + b
    mx.eval(a)  # Massive overhead!
Implicit evaluation is triggered by print(a), a.item(), np.array(a), and Python control flow on an array value such as if a > 0:
# Lists must be mx.array
a[[0, 1]] # ValueError!
a[mx.array([0, 1])] # Works
# Slice indices must be Python ints
i = mx.array(2)
x[i:i+2] # ValueError!
x[i.item():i.item()+2] # Works (forces eval)
# Slices create COPIES, not views (opposite of NumPy)
b = a[:]
b[2] = 0 # a is unchanged!
# Boolean mask READS not supported
a[mask] # Not supported - use mx.where()
# No bounds checking - out-of-bounds indexing is undefined behavior (often garbage values)
For accumulating updates, use at[] syntax:
a = a.at[idx].add(1) # Properly accumulates at duplicate indices
See references/array-indexing.md for complete patterns.
# Conv2d uses NHWC (not NCHW like PyTorch)
x_mlx = mx.array(x_torch.numpy().transpose(0, 2, 3, 1))
# Override __call__, not forward()
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(10, 10)

    def __call__(self, x):  # NOT forward()
        return self.layer(x)
# No dtype in constructors - use set_dtype()
layer = nn.Linear(10, 10)
layer.set_dtype(mx.bfloat16)
See references/neural-networks.md for layer equivalents.
# float64 is CPU-only: GPU ops on float64 arrays raise
a = mx.array([1.0], dtype=mx.float64)
mx.exp(a, stream=mx.gpu) # RuntimeError!
# Solutions:
mx.exp(a, stream=mx.cpu)
mx.exp(a.astype(mx.float32))
# bfloat16 from external sources gets misinterpreted
import numpy as np
from ml_dtypes import bfloat16
x = np.array(1., dtype=bfloat16)
mx.array(x) # Returns complex64!
mx.array(x.astype(np.float32), dtype=mx.bfloat16) # Correct
See references/dtypes.md for full type support table.
from functools import partial
state = [model.state, optimizer.state, mx.random.state]  # Include random!
@partial(mx.compile, inputs=state, outputs=state)
def train_step(x, y):
    loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
    optimizer.update(model, grads)
    return loss
# No print() of arrays in compiled functions - tracing uses placeholder arrays with no data, so it errors
# String decoding triggers recompilation - decode outside loop
See references/compilation.md for recompilation triggers.
| Type | GPU | Notes |
|---|---|---|
| float32 | Yes | Default float |
| float16 | Yes | |
| bfloat16 | Yes | M3+ recommended |
| float64 | CPU only | GPU throws! |
| int8-64, uint8-64 | Yes | |
| complex64 | Partial | No matmul |
| PyTorch | MLX |
|---|---|
| tensor.to('cuda') | Not needed (unified memory) |
| nn.Module.forward() | nn.Module.__call__() |
| NCHW format | NHWC format |
| torch.gather() | mx.take_along_axis() |
| torch.scatter_add_() | arr.at[idx].add() |
Unsupported operations and workarounds:
- np.nonzero() - restructure the algorithm
- np.unique() - pre-sort or use dicts
- arr[bool_mask] reads - use mx.where()
- np.linalg.det(), np.linalg.lstsq()

Complete training example:
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from functools import partial
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = [nn.Linear(784, 256), nn.Linear(256, 10)]

    def __call__(self, x):
        for layer in self.layers[:-1]:
            x = mx.maximum(layer(x), 0)  # ReLU
        return self.layers[-1](x)

def loss_fn(model, x, y):
    return nn.losses.cross_entropy(model(x), y, reduction="mean")

model = Model()
optimizer = optim.AdamW(learning_rate=1e-3)
state = [model.state, optimizer.state, mx.random.state]

@partial(mx.compile, inputs=state, outputs=state)
def train_step(x, y):
    loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
    optimizer.update(model, grads)
    return loss

# Placeholder data for illustration: replace with a real dataloader
num_epochs = 3
dataloader = [(mx.random.normal((32, 784)), mx.random.randint(0, 10, (32,))) for _ in range(10)]

for epoch in range(num_epochs):
    for x_batch, y_batch in dataloader:
        loss = train_step(x_batch, y_batch)
        mx.eval(state)
    print(f"Epoch {epoch}: {loss.item():.4f}")