
Wiring Apple's Neural Engine Into a Zig Inference Runtime: What We Learned Building ANE Dispatch

We're building an on-device LLM inference runtime in Zig. No dependencies. Single static binary. Runs transformer models on Apple Silicon by reading a schedule that maps each operation to the optimal compute engine — CPU, GPU, or Neural Engine.

The GPU path was straightforward. Metal shared memory gives you zero-copy buffers between CPU and GPU. Write from one side, read from the other, call syncGpu() at the boundary. Done.

The ANE path was not straightforward.


I – Why ANE Matters for On-Device Inference

Apple's Neural Engine is the most underutilized compute resource on every Mac and iPhone shipping today. The M-series ANE delivers 15–18 TFLOPS of fp16 throughput. For context, the GPU on an M2 Max delivers roughly 13.6 TFLOPS fp32. The ANE is faster for the operations that dominate transformer inference — dense matrix multiplications — and it draws significantly less power doing it.

But almost nobody uses it for LLM inference. The reason is simple: Apple doesn't provide a public low-level API. CoreML exists, but it's a high-level framework designed for classification and detection models with fixed input shapes. It's not designed for autoregressive generation where you're running the same model thousands of times with a KV cache that grows each step.

Our approach bypasses CoreML entirely. We compile ANE kernels directly through the private _ANECompile interface, pack data into the spatial format ANE expects, and evaluate through _ANEEvaluate. This is what Apple's own ML frameworks do internally — we're just doing it from Zig through a thin C bridge.

The question this article answers: what happens when you actually try to wire ANE dispatch into a real inference pipeline?


II – The Existing Architecture

Our runtime, zml-flow, has four layers:

  1. Genome parser — reads model weights from safetensors, infers the architecture (GPT-2, Qwen, Llama), and produces a computation graph we call a "genome."
  2. HEFT scheduler — assigns each operation in the graph to the optimal engine (CPU-AMX, CPU-NEON, GPU-Metal, or ANE) based on calibration data.
  3. Schedule exporter — writes the assignment as schedule.json, a flat list of operations with engine IDs, shapes, weight bindings, and timing estimates.
  4. Runtime executor — reads schedule.json, memory-maps the weights, and dispatches operations in topological order.

flowchart LR
    A["safetensors"] --> B["Genome Parser"]
    B -- "computation graph" --> C["HEFT Scheduler"]
    C -- "engine assignments" --> D["Schedule Exporter"]
    D -- "schedule.json" --> E["Runtime Executor"]
    E --> F["CPU-AMX"]
    E --> G["GPU-Metal"]
    E --> H["ANE"]

The executor was built in the previous phase (Feature #4). It handles 13 operation types: embedding, position embedding, layer norm, RMS norm, matmul, GELU, SiLU, softmax, add, mul, scaled dot-product attention, transpose, and reshape. Every operation dispatches through our dispatch.zig module, which routes to CPU or GPU based on the engine assignment.

Engine 3 — npu_ane — was defined in the enum but never dispatched. The comment on line 18 of dispatch.zig said it plainly: "ANE (not yet wired for inference)."

Feature #5 changes that.


III – What ANE Requires That GPU Doesn't

GPU dispatch through Metal shared memory is conceptually simple. Your data is already in the right format (row-major f32 arrays). The buffer is already accessible to both CPU and GPU (unified memory). You call the Metal compute kernel, sync, and read the result.

ANE dispatch requires four things that GPU doesn't:

Spatial Packing

ANE doesn't read row-major tensors. It expects data in a "spatial" format dictated by the ANE's internal tiling architecture. A [M, K] matrix and a [K, N] matrix must be packed into a single contiguous byte buffer using a specific layout that interleaves the data for the ANE's 16-wide SIMD lanes.

Our ANE kernels expose packInput() and unpackOutput() functions that handle this conversion:

// Pack [M,K] activation and [K,N] weight into ANE spatial format
kernel.packInput(pack_buffer, a_f32, b_f32);
kernel.writeInput(pack_buffer);
kernel.eval();
kernel.readOutput(unpack_buffer);
kernel.unpackOutput(unpack_buffer, out_f32);

This pack/unpack cycle is pure overhead that GPU dispatch doesn't have. For a 768x3072 matmul, the pack buffer is ~18MB. We pre-allocate these buffers at init time and reuse them for every dispatch.
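The exact ANE byte layout is private, but the flavor of the transform is worth seeing. Here's a toy Python model of lane-group packing — the 16-wide grouping follows the text above, while the padding and reordering details are our illustration, not Apple's documented format:

```python
import numpy as np

LANES = 16  # assumed SIMD width, for illustration only

def pack_lanes(mat: np.ndarray) -> np.ndarray:
    """Illustrative lane-group packing: pad columns to a multiple of
    LANES, then reorder so each 16-wide column group is contiguous.
    This models the *kind* of transform packInput() performs; the real
    ANE spatial layout is private and more involved."""
    rows, cols = mat.shape
    padded_cols = -(-cols // LANES) * LANES  # round up to 16
    out = np.zeros((rows, padded_cols), dtype=mat.dtype)
    out[:, :cols] = mat
    # [rows, groups, LANES] -> [groups, rows, LANES]: groups contiguous
    return out.reshape(rows, -1, LANES).transpose(1, 0, 2).copy()

a = np.arange(2 * 20, dtype=np.float32).reshape(2, 20)
packed = pack_lanes(a)
print(packed.shape)  # (2, 2, 16): 2 lane groups of 2 rows x 16 lanes
```

The zero columns introduced by padding are exactly the kind of dead bytes the pre-allocated pack buffers carry on every dispatch.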

Minimum Dimension Constraints

ANE has a hard minimum of 32 elements per dimension, derived from its 64-byte alignment requirement. If your matmul has M=1 (single-token decode), you can't just send it to ANE — you need to zero-pad to M=32, run the kernel, then trim the output back to 1 row.

This is a consequence of how ANE hardware tiles computation. The Neural Engine processes data in fixed-size tiles, and dimensions smaller than the tile size waste compute. For decode-phase inference where every forward pass processes exactly one token (M=1), this means 31/32 of the ANE's computation is wasted on padding.

The practical implication: ANE is most efficient for prefill (M=seq_len, often 128-2048) and least efficient for single-token decode. This is the opposite of what you'd naively expect — you'd think the 15 TFLOPS engine should handle every matmul. But the padding overhead at M=1 makes CPU-AMX competitive or faster for decode-phase matmuls with small sequence dimensions.
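The padding arithmetic is simple enough to model directly. A Python sketch of the pad-to-32 rule described above (a model of the constraint, not measured hardware behavior):

```python
def pad_rows(m: int, tile: int = 32) -> tuple[int, float]:
    """Round M up to the ANE's 32-row minimum and report the fraction
    of compute spent on padding rows."""
    padded = max(tile, -(-m // tile) * tile)  # ceil to tile multiple
    waste = (padded - m) / padded
    return padded, waste

print(pad_rows(1))    # (32, 0.96875)  -> 31/32 wasted at decode
print(pad_rows(128))  # (128, 0.0)     -> prefill pads nothing
```

At M=1 the waste fraction is 31/32, which is why CPU-AMX stays competitive for decode-phase matmuls.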

Channel Limits

ANE has a 32,768-element limit per channel dimension. A matmul with K > 32768 or N > 32768 simply cannot run on ANE in a single kernel invocation. For GPT-2 this isn't a problem — the vocabulary projection is 768 -> 50257, and 50257 exceeds the limit, but that's only one operation at the output layer. For the 90% of matmuls inside the transformer blocks (768x768, 768x3072, 3072x768), all dimensions are well within limits.

We handle this at compile time: during compileKernelsForSchedule(), we skip any operation with dimensions exceeding 32K and log a warning. The executor falls back to CPU dispatch for those shapes.
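The compile-time screen reduces to a one-line predicate. A Python sketch (the function name and structure are illustrative; the 32K constant comes from the text):

```python
ANE_MAX_CHANNEL = 32768  # per-channel element limit

def ane_eligible(m: int, k: int, n: int) -> bool:
    """Shape screen mirroring the check described for
    compileKernelsForSchedule(): channel dims must fit the limit.
    M is handled separately via padding, so only K and N gate here."""
    return k <= ANE_MAX_CHANNEL and n <= ANE_MAX_CHANNEL

print(ane_eligible(1, 768, 3072))   # True: transformer-block matmul
print(ane_eligible(1, 768, 50257))  # False: vocab projection too wide
```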

Kernel Compilation

GPU kernels (Metal compute shaders) are compiled once and cached by the Metal framework. ANE kernels must be compiled explicitly for each unique shape. A matmul with shape (1, 768, 768) is a different compiled kernel than (1, 768, 3072).

For GPT-2, we identified 3–5 unique matmul shapes across the entire model. Compilation takes 10–50ms per kernel. We compile all kernels eagerly at executor init time rather than lazily during inference, because a 50ms latency spike on the first forward pass is unacceptable for interactive generation.

// At executor init: scan schedule, compile all ANE kernels upfront
ane_dispatcher.compileKernelsForSchedule(schedule.operations);

The kernel cache maps shape keys like "matmul_1x768x768" to compiled ANEMatmulKernel instances. If the same shape appears 24 times in a 12-layer model (once per attention projection, once per MLP), we compile it once and reuse it 24 times.
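The cache itself is a plain shape-keyed map. A Python sketch of the compile-once-reuse-many behavior (compile_fn stands in for the real ANE compilation call; the key format follows the text):

```python
class KernelCache:
    """Shape-keyed kernel cache: compile once per unique (op, M, K, N),
    reuse on every subsequent lookup."""
    def __init__(self, compile_fn):
        self.compile_fn = compile_fn
        self.cache = {}
        self.compiles = 0

    def get(self, op: str, m: int, k: int, n: int):
        key = f"{op}_{m}x{k}x{n}"
        if key not in self.cache:
            self.cache[key] = self.compile_fn(key)
            self.compiles += 1
        return self.cache[key]

cache = KernelCache(compile_fn=lambda key: f"kernel<{key}>")
for _ in range(24):                 # same shape at 24 call sites
    cache.get("matmul", 1, 768, 768)
print(cache.compiles)  # 1: compiled once, reused 24 times
```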


IV – Hardware Fences: The GPU↔ANE Synchronization Problem

When the HEFT scheduler produces a mixed-engine schedule — attention on GPU, MLP on ANE, or vice versa — the executor must ensure data coherence at engine transitions. GPU writes are not instantly visible to ANE, and ANE writes are not instantly visible to GPU.

The synchronization mechanism is MTLSharedEvent, Apple's cross-engine signaling primitive. It's the same mechanism used internally by CoreML when it orchestrates multi-engine execution plans. We wrap it in a HardwareFence:

// GPU → ANE transition
dispatch_ctx.syncGpu();            // Flush pending GPU writes
fence.signalFromGpu(dispatch_ctx); // Signal MTLSharedEvent
// ... ANE side ...
fence.waitForGpu();                // Block until GPU signal arrives
// Now safe to read GPU output on ANE

The reverse direction (ANE → GPU) uses signalFromAne() / waitForAne().

Measured fence latency on M2 Max: ~0.1ms per transition. This is fast enough that even a schedule with 24 GPU↔ANE transitions per forward pass (one per layer, both directions) adds only ~2.4ms of synchronization overhead.

sequenceDiagram
    participant GPU as GPU (Metal)
    participant Fence as MTLSharedEvent
    participant ANE as ANE

    Note over GPU,ANE: GPU → ANE transition
    GPU->>GPU: complete pending writes
    GPU->>Fence: syncGpu() + signalFromGpu()
    Fence-->>ANE: waitForGpu() unblocks
    ANE->>ANE: safe to read GPU output

    Note over GPU,ANE: ANE → GPU transition
    ANE->>ANE: complete eval()
    ANE->>Fence: signalFromAne()
    Fence-->>GPU: waitForAne() unblocks
    GPU->>GPU: safe to read ANE output

Compare this to the alternative — explicit memory copies between GPU and ANE address spaces — which would cost 0.5–1ms per copy for typical activation sizes. The hardware fence approach is 5-10x faster because the data never moves; only the synchronization signal travels.

When Fences Insert

The executor's runForwardPass() tracks the previous operation's engine. On any transition involving npu_ane, it calls the appropriate fence methods:

Transition              Action
gpu_metal → npu_ane     syncGpu() + signalFromGpu() + waitForGpu()
npu_ane → gpu_metal     signalFromAne() + waitForAne()
npu_ane → cpu_amx       signalFromAne() + waitForAne()
cpu_amx → npu_ane       No fence needed (CPU writes are immediately coherent)

That last row is important. CPU writes to unified memory are immediately visible because there's no write buffer or cache coherence delay on the CPU→ANE path. Only GPU writes require explicit synchronization because the GPU has its own command buffer pipeline.
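The transition table collapses into a small routing function. A Python model (the fence method names follow the article; the routing logic is our reconstruction of the table above):

```python
def fence_actions(prev: str, curr: str) -> list[str]:
    """Which fence calls fire when the executor crosses an engine
    boundary. Only transitions involving the ANE need fencing, and
    CPU -> ANE is coherent for free."""
    if prev == curr or "npu_ane" not in (prev, curr):
        return []
    if prev == "gpu_metal" and curr == "npu_ane":
        return ["syncGpu", "signalFromGpu", "waitForGpu"]
    if prev == "npu_ane":              # ANE -> GPU or ANE -> CPU
        return ["signalFromAne", "waitForAne"]
    return []                          # cpu_amx -> npu_ane: no fence

print(fence_actions("gpu_metal", "npu_ane"))
print(fence_actions("cpu_amx", "npu_ane"))  # []
```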


V – Fused Kernels: Why 3 Fused Ops Beat 3 Separate Kernels

Individual ANE matmul dispatch achieves roughly 30% of the ANE's theoretical throughput. The overhead is dominated by the pack/unpack cycle and kernel launch latency, not the actual computation.

The solution is kernel fusion. Instead of dispatching three separate operations for an MLP block (matmul → GELU → matmul), we dispatch a single ANEMLPKernel that fuses all three:

// Without fusion: 3 kernel launches, 3 pack/unpack cycles
dispatchMatmul(x, W1, intermediate, M, K, N1);   // ~0.15ms
dispatchGelu(intermediate, activated, N1);         // ~0.08ms
dispatchMatmul(activated, W2, output, M, N1, N2); // ~0.15ms
// Total: ~0.38ms, 3 pack/unpack cycles

// With fusion: 1 kernel launch, 1 pack/unpack cycle
dispatchMLP(x, W1, W2, bias, output, d_model, d_ff, seq_len); // ~0.20ms
// Total: ~0.20ms, 1 pack/unpack cycle

The fused kernel runs the entire MLP computation on-chip within the ANE's SRAM, avoiding intermediate writes to main memory. This pushes utilization from ~30% to ~45% for MLP blocks.

We support two fusion patterns:

Pattern   Ops Consumed                   Kernel            Use Case
MLP       matmul → gelu → matmul         ANEMLPKernel      GPT-2, BERT-style FFN
SwiGLU    matmul → silu → mul → matmul   ANESwiGLUKernel   Llama, Mistral, Qwen FFN

Fusion detection happens in the executor's forward pass loop. Before dispatching each operation, we call detectFusion() which performs a lookahead over the next 3–4 operations:

// In runForwardPass, before dispatching op[i]:
if (ane_dispatcher.detectFusion(operations, i)) |fusion| {
    // Dispatch fused kernel, advance loop index past consumed ops
    switch (fusion.pattern_type) {
        .mlp => ane_dispatcher.dispatchMLP(...),
        .swiglu => ane_dispatcher.dispatchSwiGLU(...),
    }
    i += fusion.op_count;
    continue;
}

The critical constraint: all operations in a fusion pattern must be assigned to engine 3 (ANE). If the scheduler places the GELU on CPU, fusion is impossible and individual dispatch proceeds. This means the HEFT scheduler's engine assignments directly affect ANE utilization — a fusion-aware scheduler will group MLP ops onto ANE together, while a naive scheduler might split them across engines.

flowchart TD
    A["op assigned to ANE?"] -- No --> B["Dispatch individually"]
    A -- Yes --> C["detectFusion() lookahead"]
    C -- "matmul→gelu→matmul" --> D["dispatchMLP()"]
    C -- "matmul→silu→mul→matmul" --> E["dispatchSwiGLU()"]
    C -- "no pattern" --> F["Dispatch single ANE op"]
    D --> G["advance i by 3"]
    E --> H["advance i by 4"]
    F --> I["advance i by 1"]
    B --> I
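The lookahead itself fits in a few lines. A Python model of the pattern match (op-type sequences and the engine-3 requirement come from the text; the data representation is illustrative):

```python
MLP = ("matmul", "gelu", "matmul")
SWIGLU = ("matmul", "silu", "mul", "matmul")

def detect_fusion(ops, i):
    """Lookahead sketch of detectFusion(): match an op-type sequence
    starting at index i, requiring every op in the window to be
    assigned to engine 3 (ANE). ops is a list of (op_type, engine)
    pairs; longer patterns are tried first."""
    for name, pattern in (("swiglu", SWIGLU), ("mlp", MLP)):
        window = ops[i:i + len(pattern)]
        if (len(window) == len(pattern)
                and all(t == p and e == 3
                        for (t, e), p in zip(window, pattern))):
            return name, len(pattern)
    return None

print(detect_fusion([("matmul", 3), ("gelu", 3), ("matmul", 3)], 0))
print(detect_fusion([("matmul", 3), ("gelu", 0), ("matmul", 3)], 0))
```

The second call returns None: the GELU sits on engine 0, so the window falls back to individual dispatch, exactly the scheduler-dependence described above.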

VI – The Fallback Guarantee

ANE dispatch is optional by design. Every code path that attempts ANE dispatch has a CPU fallback:

if (engine == .npu_ane and ane_dispatcher.isAvailable() and !config.deterministic) {
    if (ane_dispatcher.dispatchMatmul(a, b, out, M, K, N)) return;
    // ANE dispatch returned false — fall through to CPU
}
// CPU dispatch (always works)
dispatch.matmul(&ctx, .cpu_amx, &input_buf, weight, &output_buf, a_elems, c_elems, M, K, N);

There are five conditions that trigger fallback:

  1. ANE unavailable — zml_ane_init() failed (no ANE hardware, driver issue). The ANEDispatcher initializes with ane_ctx = null and isAvailable() returns false. Not an error — it's expected on Intel Macs or in CI.
  2. Deterministic mode — the --deterministic flag routes all ANE ops to CPU. ANE fp16 arithmetic is not bit-reproducible across runs due to hardware scheduling non-determinism. Our constitution (Principle IX) requires deterministic debugging support.
  3. Dimension limits — K > 32768 or N > 32768. Detected at compile time, those shapes are never compiled and dispatch returns false immediately.
  4. Compilation failure — ANE kernel compilation failed for a specific shape. The kernel cache has no entry, dispatch returns false.
  5. Runtime eval failure — kernel.eval() returned an error. Rare, but possible under memory pressure. Caught and logged, falls through to CPU.

The dispatch functions return bool instead of error unions deliberately. On the hot path of inference, we don't want error handling overhead — just a branch. True means ANE handled it. False means try CPU.
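The shape of that contract is easy to show in miniature. A Python sketch of fallback-first routing (function names and the returned engine label are illustrative):

```python
def dispatch_matmul(try_ane, cpu_matmul, *args) -> str:
    """Fallback-first routing: try_ane returns True when ANE handled
    the op, False for any of the five fallback conditions; the CPU
    path always succeeds. Returns which engine ran, for illustration."""
    if try_ane(*args):
        return "ane"
    cpu_matmul(*args)
    return "cpu"

# An ANE path that declines (e.g. shape never compiled) falls through:
print(dispatch_matmul(lambda *a: False, lambda *a: None))  # cpu
print(dispatch_matmul(lambda *a: True, lambda *a: None))   # ane
```

A bool is the whole error-handling story on the hot path: one branch, no unwinding.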


VII – The SRAM Budget Problem

ANE has 32MB of on-chip SRAM. When a kernel's working set fits within SRAM, you get the full 15-18 TFLOPS. When it exceeds SRAM, the ANE must spill to main memory, and throughput drops by approximately 30%.

For GPT-2 medium with d_model=1024 and d_ff=4096, the MLP weights alone are 1024 * 4096 * 2 (fp16) = 8MB per direction, 16MB total. Add the activation buffers and you're at ~20MB — safely within the 32MB budget.

For Qwen3.5-0.8B with d_model=1536 and d_ff=8960, the MLP weights are 1536 * 8960 * 2 = ~26MB per direction, ~52MB for both projections. That blows through the 32MB budget before activations even enter the picture.

We don't solve SRAM overflow in this phase — that's a future optimization involving weight tiling and double-buffered compilation (REC-001.7 in our roadmap). For now, the HEFT scheduler's calibration data accounts for SRAM spill: if ANE calibration shows degraded throughput for a particular shape, the scheduler assigns it to GPU instead. The executor respects whatever the scheduler decided.
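The budget check is pure arithmetic. A Python sketch reproducing the two worked examples above (the 32MB figure and dimensions come from the text):

```python
SRAM_BYTES = 32 * 1024 * 1024  # ANE on-chip SRAM budget

def mlp_weight_bytes(d_model: int, d_ff: int, dtype_bytes: int = 2) -> int:
    """fp16 weight footprint of both MLP projections (up + down)."""
    return 2 * d_model * d_ff * dtype_bytes

gpt2_medium = mlp_weight_bytes(1024, 4096)  # fits, with headroom
qwen = mlp_weight_bytes(1536, 8960)         # exceeds the budget
print(gpt2_medium / 2**20, gpt2_medium < SRAM_BYTES)  # 16.0 True
print(qwen / 2**20, qwen < SRAM_BYTES)                # 52.5 False
```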


VIII – What We Couldn't ANE: Attention

Scaled dot-product attention cannot run on ANE. The reason is causal masking.

During autoregressive generation, each query position can only attend to itself and all previous positions. This requires applying a triangular mask to the attention scores before softmax. ANE's fixed-function units don't support conditional masking — they execute the same operation on every element.

There's an exception: at decode time with seq_len=1, there's no masking needed (every position is "previous" relative to the single current position). The ANETransformerBlockKernel in our codebase supports this case, fusing the entire transformer block (RMSNorm + Attention + RMSNorm + SwiGLU) into a single kernel with ~94% ANE utilization. But that's a Phase 2 optimization (REC-001.4) — for this phase, attention remains on CPU where we implement the full SDPA decomposition with causal masking and KV cache.

This creates a natural engine split: attention on CPU/GPU, everything else on ANE. The HEFT scheduler already models this — attention operations have ANE cost set to infinity in calibration, so they're never assigned to engine 3.

flowchart LR
    subgraph ANE1 ["ANE"]
        A["RMS Norm"] --> B["Q/K/V Projection"]
    end

    subgraph CPUGPU ["CPU / GPU"]
        C["Scaled Dot-Product Attention\n+ causal mask + KV cache"]
    end

    subgraph ANE2 ["ANE"]
        D["Output Projection"]
        E["RMS Norm"] --> F["SwiGLU / MLP"]
    end

    B -- "fence" --> C
    C -- "fence" --> D
    D --> E

IX – Architecture Summary

Here's the complete dispatch architecture after wiring ANE:

schedule.json
     │
     ▼
┌─────────────────────────────────────────────────┐
│ Executor.runForwardPass()                        │
│                                                  │
│  for each operation:                             │
│    ├── detect fusion (lookahead 3-4 ops)         │
│    ├── check engine transition → fence if needed │
│    └── dispatch:                                 │
│         ├── engine 0,1: dispatch.zig (CPU)       │
│         ├── engine 2:   dispatch.zig (GPU/Metal) │
│         └── engine 3:   ane_dispatcher.zig (ANE) │
│              ├── kernel cache lookup             │
│              ├── packInput → spatial format      │
│              ├── eval on ANE hardware            │
│              ├── unpackOutput → f32              │
│              └── fallback → CPU if any step fails│
└─────────────────────────────────────────────────┘

The ANEDispatcher module (ane_dispatcher.zig) owns:

  • ANE context lifecycle
  • Kernel compilation cache (keyed by "{op}_{M}x{K}x{N}")
  • Pre-allocated pack/unpack byte buffers
  • Hardware fence (MTLSharedEvent wrapper)
  • Fusion detection (MLP, SwiGLU patterns)
  • Verbose logging for all ANE events

The executor owns the orchestration: forward pass loop, engine transition detection, fusion-aware dispatch, and the CPU fallback path.


X – Numbers We're Targeting

We can't run end-to-end benchmarks yet due to a known Zig 0.15.2 + macOS 26.4 SDK linker incompatibility (the entire codebase compiles but can't link against Apple frameworks on Tahoe). But based on individual ANE kernel benchmarks from our npu_ane.zig module:

Operation      Shape              ANE Latency   CPU-AMX Latency   Speedup
Matmul         1 x 768 x 768      ~0.08ms       ~0.12ms           1.5x
Matmul         128 x 768 x 3072   ~0.35ms       ~1.2ms            3.4x
Fused MLP      768 x 3072 x 1     ~0.20ms       ~0.45ms           2.3x
Fused SwiGLU   1024 x 2816 x 1    ~0.25ms       ~0.55ms           2.2x

The prefill phase (seq_len=128) benefits most from ANE due to larger M dimensions. The decode phase (seq_len=1, M=1 padded to 32) benefits less due to padding waste.

For a full GPT-2 forward pass with 12 layers, each containing 4 matmuls (the Q, K, and V projections plus the output projection) alongside the MLP block, the ANE path for MLP blocks alone should save approximately 3ms per forward pass compared to all-CPU execution. At 128 tokens generated, that's ~384ms saved — meaningful for interactive generation.
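The arithmetic behind that estimate, using the fused-MLP row from the table above:

```python
# Per-layer saving: fused MLP ~0.20ms on ANE vs ~0.45ms on CPU-AMX
per_layer_saving_ms = round(0.45 - 0.20, 2)   # 0.25 ms per MLP block
per_pass_ms = per_layer_saving_ms * 12        # 12 transformer layers
print(per_pass_ms)        # 3.0 ms saved per forward pass
print(per_pass_ms * 128)  # 384.0 ms over 128 generated tokens
```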

The real payoff comes in Phase 2 when we add the ANETransformerBlockKernel fusion (REC-001.4) and double-buffered compilation (REC-001.7). At 94% ANE utilization with full-block fusion, the theoretical improvement is 4-6x over CPU for the entire non-attention portion of the model.


XI – What This Phase Taught Us

ANE is not a drop-in accelerator. You can't just send tensors to ANE the way you send them to a GPU. Every dispatch requires format conversion (spatial packing), every shape needs an explicitly compiled kernel, and every cross-engine transition needs hardware synchronization. The engineering cost is real.

Fallback-first design is non-negotiable. We implemented ANE dispatch as a bool-returning function that the executor can ignore. This made development dramatically easier — we could wire up the dispatch path without worrying about breaking existing CPU/GPU inference. If ANE fails for any reason, inference continues at full correctness on CPU.

Kernel fusion is where the ROI lives. Individual ANE matmul dispatch gains 1.5-3.4x over CPU depending on shape. Fused MLP dispatch gains less in raw speedup but eliminates 2/3 of the pack/unpack overhead, which is the real bottleneck at small batch sizes. The lesson: don't benchmark single operations and extrapolate. Benchmark the actual dispatch pattern including all overhead.

The HEFT scheduler is the real decision maker. The executor doesn't decide what runs on ANE — the scheduler does, based on calibration data that accounts for all the overhead (packing, fencing, SRAM limits). Our executor just respects the assignment and provides a reliable fallback. This separation of concerns is what makes the architecture composable: better calibration data automatically improves ANE utilization without changing the executor code.


This is part of the zml-flow inference runtime series. Next up: MPS transposed matmul for GPU (REC-001.3) and E5 MIL code generation for full ANE block fusion (REC-001.4).

– Antonio

"Simplicity is the ultimate sophistication."