Cactus Transpiler

The Cactus Transpiler takes a PyTorch model and turns it into a Cactus runtime graph that runs on ARM devices. You write normal PyTorch, and the transpiler handles the rest: capturing the computation graph, fusing ops, quantizing weights, and producing a bundle that loads with zero-copy memory mapping.

From PyTorch to On-Device in 3 Steps

Step 1: Write a PyTorch Model

Any torch.nn.Module that torch.export can capture works. Here's a small feedforward classifier:

import torch
import torch.nn as nn

class TextClassifier(nn.Module):
    def __init__(self, vocab_size=32000, embed_dim=256, num_classes=4):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.fc1 = nn.Linear(embed_dim, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, input_ids):
        x = self.embedding(input_ids)       # (batch, seq, embed_dim)
        x = x.mean(dim=1)                   # (batch, embed_dim)
        x = torch.nn.functional.relu(self.fc1(x))
        x = self.fc2(x)
        return x

model = TextClassifier()
torch.save(model.state_dict(), "classifier.pt")

Step 2: Convert Weights

Convert the model weights to Cactus format. This quantizes them and writes a weights_manifest.json that the transpiler uses to bind weights at load time:

cactus convert ./classifier.pt ./classifier-cactus --bits 4
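
To make the effect of --bits 4 concrete, here is a minimal sketch of symmetric 4-bit quantization for one group of weights. It is a generic illustration only: the scheme cactus convert actually uses is not described on this page, and quantize_group and dequantize_group are hypothetical helpers, not part of the Cactus API.

```python
def quantize_group(values, bits=4):
    """Symmetric quantization of one group of floats to signed integer codes.

    Generic illustration only; not the scheme `cactus convert` uses.
    """
    qmax = 2 ** (bits - 1) - 1              # 7 for 4 bits
    peak = max(abs(v) for v in values)
    scale = peak / qmax if peak > 0 else 1.0
    codes = [max(-qmax, min(qmax, round(v / scale))) for v in values]
    return codes, scale


def dequantize_group(codes, scale):
    """Map integer codes back to approximate float values."""
    return [c * scale for c in codes]


weights = [0.70, -0.30, 0.12, 0.00]
codes, scale = quantize_group(weights, bits=4)
restored = dequantize_group(codes, scale)
max_error = max(abs(a - b) for a, b in zip(weights, restored))
```

In this sketch each group stores one float scale plus small integer codes, and the rounding error per weight is at most half a quantization step (scale / 2).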

Step 3: Transpile and Run

# Transpile: captures the graph, optimizes, and saves a bundle
cactus transpile ./classifier.pt --weights-dir ./classifier-cactus

# Run the transpiled bundle
cactus run ./classifier-cactus

That's it. The output is a self-contained bundle with memory-mapped weights that loads in milliseconds on any ARM device.


How It Works

The transpiler is a six-stage compiler pipeline:

PyTorch nn.Module
      |
      v
+------------------+
| 1. Canonicalize  |  Wrap the model in a task-specific adapter
|    model         |  (causal LM, encoder, multimodal, etc.)
+--------+---------+
         |
         v
+------------------+
| 2. torch.export  |  Capture the forward pass as an FX graph
|                  |  (no Python control flow, tensor-only)
+--------+---------+
         |
         v
+------------------+
| 3. Import to IR  |  FX nodes -> canonical IRGraph
|                  |  Bind converted weight files from manifest
+--------+---------+
         |
         v
+------------------+
| 4. Canonicalize  |  Normalize ops, constant-fold, remove no-ops,
|    IR            |  insert precision casts, dead code elimination
+--------+---------+
         |
         v
+------------------+
| 5. Fuse patterns |  Collapse subgraphs into single ops:
|                  |  RMSNorm, RoPE, attention, MLP, LSTM, etc.
+--------+---------+
         |
         v
+------------------+
| 6. Lower to      |  IRNodes -> CactusGraph ops
|    CactusGraph   |  Constants -> mmap_weights() or inline tensors
+------------------+
         |
         v
   Saved bundle: graph.cactus + manifest.json + weights

Stage 1: Canonicalize the Model

The transpiler doesn't try to handle every model's forward signature as-is. It wraps the model in a thin adapter that exposes a stable, task-specific interface:

| Task | Input | Output |
| --- | --- | --- |
| causal_lm_logits | input_ids | next-token logits |
| multimodal_causal_lm_logits | input_ids, image/audio features | next-token logits |
| encoder_hidden_states | input_features | hidden states |
| ctc_logits | input_features | CTC logits |
| tdt_transcription | input_features | TDT token IDs |

This is where model-family-specific knowledge lives (Gemma4, Qwen, LFM2, Parakeet, etc.). The adapter also injects use_cache=False and return_dict=False so the export boundary stays tensor-only.
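
The adapter idea can be sketched in plain Python. This is an assumption-laden illustration, not the code in model_adapters.py: CausalLMLogitsAdapter and toy_forward are hypothetical names, and a real adapter would presumably wrap a torch.nn.Module rather than a plain function.

```python
from types import SimpleNamespace


class CausalLMLogitsAdapter:
    """Hypothetical adapter: exposes input_ids -> logits and nothing else."""

    def __init__(self, model):
        self.model = model

    def __call__(self, input_ids):
        # Pin the flags so the wrapped forward returns a plain tuple
        # instead of a cache-carrying, dict-like output object.
        outputs = self.model(input_ids, use_cache=False, return_dict=False)
        return outputs[0]  # first tuple element holds the logits


def toy_forward(input_ids, use_cache=True, return_dict=True):
    """Stand-in for an HF-style forward; returns fake per-token logits."""
    logits = [[float(t), float(t) + 1.0] for t in input_ids]
    if return_dict:
        cache = "cache" if use_cache else None
        return SimpleNamespace(logits=logits, past_key_values=cache)
    return (logits,)


adapter = CausalLMLogitsAdapter(toy_forward)
logits = adapter([3, 5])
```

Because the flags are fixed inside the wrapper, the caller only ever sees tensors-in, tensors-out, which is the property the export step relies on.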

Stage 2: Capture with torch.export

The adapted model is captured using torch.export, producing an FX graph where every operation is a canonical ATen function call (e.g. aten.linear.default, aten.silu.default).

This is the generality boundary. After capture, the transpiler works with ATen op names, not layer names. It doesn't care whether your model calls its attention layer self_attn or mha — it recognizes the computation pattern.

Stage 3: Import to IR

The FX graph is converted into Cactus's intermediate representation (IRGraph), a simple dataflow graph of IRNodes connected by IRValues.

During import, constants that correspond to converted weight files are resolved through weights_manifest.json. This metadata follows the value through the pipeline so that lowering can use mmap_weights() instead of embedding the tensor into the graph.
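
As a rough mental model, the IR can be pictured as a few small dataclasses. The class names IRGraph, IRNode, and IRValue come from the text above, but every field, the add helper, and the manifest contents below are assumptions made for illustration only.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class IRValue:
    name: str
    weight_file: Optional[str] = None   # set when bound through the manifest


@dataclass
class IRNode:
    op: str
    inputs: list
    output: IRValue


@dataclass
class IRGraph:
    nodes: list = field(default_factory=list)

    def add(self, op, inputs, out_name):
        """Append a node and return the value it produces."""
        node = IRNode(op, list(inputs), IRValue(out_name))
        self.nodes.append(node)
        return node.output


# Hypothetical manifest contents: constant name -> converted weight file.
manifest = {"fc1.weight": "fc1_weight.bin"}

graph = IRGraph()
x = IRValue("x")
w = IRValue("fc1.weight", weight_file=manifest.get("fc1.weight"))
h = graph.add("linear", [x, w], "h")
y = graph.add("relu", [h], "y")

# Values that carry a weight file are the ones lowering could memory-map.
mmap_candidates = [v.name for n in graph.nodes for v in n.inputs if v.weight_file]
```

The point of the sketch is that the binding travels with the value itself, so later stages do not need to look the weight up again by name.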

Stage 4: Canonicalize the IR

This stage runs several cleanup passes over the IR:

  • Rename ops to canonical spellings
  • Rewrite transpose/movedim into permute
  • Constant-fold trivial expressions
  • Materialize ones, zeros, arange constants
  • Remove no-op casts, views, slices
  • Insert FP16 precision casts where the runtime requires them
  • Dead code elimination
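
Two of the passes above, no-op cast removal and dead code elimination, can be sketched on a toy IR where each node is an (output, op, inputs, attrs) tuple. The tuple layout and both function names are assumptions for illustration; the real passes operate on IRGraph, and this sketch does not handle a removed cast that is itself a graph output.

```python
def remove_noop_casts(nodes, dtypes):
    """Drop casts whose input already has the target dtype, rewiring users."""
    alias, kept = {}, []
    for out, op, inputs, attrs in nodes:
        inputs = [alias.get(i, i) for i in inputs]
        if op == "cast" and dtypes.get(inputs[0]) == attrs["dtype"]:
            alias[out] = inputs[0]          # users read the cast's input directly
            continue
        kept.append((out, op, inputs, attrs))
    return kept


def eliminate_dead_code(nodes, graph_outputs):
    """Keep only nodes that some graph output transitively depends on."""
    live, kept = set(graph_outputs), []
    for out, op, inputs, attrs in reversed(nodes):
        if out in live:
            live.update(inputs)
            kept.append((out, op, inputs, attrs))
    return kept[::-1]


nodes = [
    ("h", "linear", ["x", "w"], {}),
    ("h16", "cast", ["h"], {"dtype": "fp16"}),   # h is already fp16: a no-op
    ("unused", "neg", ["h"], {}),                # nothing consumes this
    ("y", "relu", ["h16"], {}),
]
dtypes = {"x": "fp16", "w": "fp16", "h": "fp16"}

nodes = remove_noop_casts(nodes, dtypes)
nodes = eliminate_dead_code(nodes, ["y"])
ops = [(out, op, inputs) for out, op, inputs, _ in nodes]
```

Running the cast pass first matters here: once relu reads h directly, nothing else keeps the cast alive, and the backward liveness walk removes the unused neg node as well.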

Stage 5: Fuse Patterns

The optimizer recognizes computational subgraphs and collapses them into single high-level ops:

| Pattern | Fused Op |
| --- | --- |
| variance -> rsqrt -> mul -> mul | rms_norm |
| cos/sin interleave on position freqs | rope |
| Q @ K^T -> scale -> mask -> softmax -> @ V | attention |
| activation(gate_proj) * up_proj -> down_proj | dense_mlp_tq_fused |
| conv -> batchnorm -> activation | conv_module |
| input/forget/cell/output gates | lstm_cell |

These fused ops map directly to optimized Cactus kernels.
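
A heavily simplified sketch of pattern fusion is shown below. It matches a fixed op sequence in a linear chain of op names, whereas a real fusion pass has to match dataflow subgraphs and check operands; fuse_chain and the chain contents are hypothetical, and only the RMSNorm pattern is taken from the table above.

```python
RMS_NORM_PATTERN = ("variance", "rsqrt", "mul", "mul")


def fuse_chain(ops, pattern, fused_name):
    """Replace each occurrence of `pattern` in a linear op chain with one op."""
    fused, i, n = [], 0, len(pattern)
    while i < len(ops):
        if tuple(ops[i:i + n]) == pattern:
            fused.append(fused_name)    # whole subsequence becomes one op
            i += n
        else:
            fused.append(ops[i])
            i += 1
    return fused


chain = ["embedding", "variance", "rsqrt", "mul", "mul", "linear", "mul"]
fused_chain = fuse_chain(chain, RMS_NORM_PATTERN, "rms_norm")
```

Note that the trailing mul is left alone because it is not preceded by the rest of the pattern, which is the kind of false positive a matcher has to avoid.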

Stage 6: Lower to CactusGraph

Each IRNode becomes one or more CactusGraph operations. Constants become:

  • graph.mmap_weights(path) — zero-copy memory-mapped from disk (most weights)
  • Inline scalar values
  • Materialized tensor nodes (small constants baked into the graph)

The output is a TranspiledGraph containing the runtime graph, input/output handles, and weight binding metadata.
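
The zero-copy idea behind memory-mapped weights can be demonstrated with Python's standard mmap module. This sketch is independent of graph.mmap_weights(): the file name is made up, and it stores float32 values for simplicity rather than any Cactus weight format.

```python
import mmap
import os
import struct
import tempfile

# Write four float32 values (native byte order) to a stand-in weight file.
values = [0.5, -1.25, 2.0, 3.5]
path = os.path.join(tempfile.mkdtemp(), "weight.bin")
with open(path, "wb") as f:
    f.write(struct.pack("4f", *values))

with open(path, "rb") as f:
    mapped = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
    # The typed view reads straight from the mapping; pages are loaded
    # by the OS on first access rather than copied up front.
    with memoryview(mapped) as base, base.cast("f") as view:
        loaded = list(view)     # copied out only so the demo can inspect it
    mapped.close()              # safe once all views are released
```

Opening the mapping is cheap regardless of file size, since no data is read until it is touched, which is the general mechanism that makes fast bundle loading possible.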


Transpiling HuggingFace Models

The most common use case is transpiling a model from HuggingFace:

# 1. Convert weights
cactus convert google/gemma-4-E2B-it ./gemma4-weights --bits 4

# 2. Transpile
cactus transpile google/gemma-4-E2B-it --weights-dir ./gemma4-weights

# 3. Run
cactus run ./gemma4-weights --prompt "Hello!"

Text Models

cactus convert Qwen/Qwen3-0.6B ./qwen3-weights --bits 4
cactus transpile Qwen/Qwen3-0.6B \
  --weights-dir ./qwen3-weights \
  --prompt "The capital of France is"

Multimodal Models (Gemma4)

Multimodal models are automatically split into components (vision_encoder, audio_encoder, lm_encoder, decoder) that are each independently captured and optimized:

cactus convert google/gemma-4-E2B-it ./gemma4-weights --bits 4
cactus transpile google/gemma-4-E2B-it \
  --task multimodal_causal_lm_logits \
  --image-file photo.jpg \
  --audio-file speech.wav \
  --prompt "Describe what you see and hear." \
  --weights-dir ./gemma4-weights

Speech Models (Parakeet TDT)

cactus convert nvidia/parakeet-tdt-0.6b-v3 ./parakeet-weights --bits 4
cactus transpile nvidia/parakeet-tdt-0.6b-v3 \
  --audio-file recording.wav \
  --task tdt_transcription \
  --weights-dir ./parakeet-weights

Running Saved Bundles

Transpiled bundles are saved to ./weights/<model>/ by default (alongside their CQ weights). Run them later without re-transpiling:

# Text
cactus run ./weights/qwen3-0.6b --prompt "Write a haiku"

# Multimodal
cactus run ./weights/gemma-4-e2b-it \
  --image photo.jpg \
  --audio speech.wav \
  --prompt "What do you see?"

# Audio
cactus run ./weights/parakeet-tdt-0.6b-v3 \
  --audio meeting.wav

cactus run accepts either a HuggingFace model id (downloads the bundle) or a local bundle directory path (detected by the presence of components/manifest.json). The standalone run-transpiled subcommand was removed — cactus run handles both cases.
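
The detection rule described above is simple enough to sketch. The function names and the returned labels are hypothetical; only the components/manifest.json marker comes from the text.

```python
from pathlib import Path
import tempfile


def is_local_bundle(target: str) -> bool:
    """True if `target` is a directory containing components/manifest.json."""
    return (Path(target) / "components" / "manifest.json").is_file()


def resolve_target(target: str) -> str:
    # Hypothetical dispatch: local bundle if the marker exists, else model id.
    return "local_bundle" if is_local_bundle(target) else "model_id"


# Build a throwaway bundle directory with just the marker file.
bundle = Path(tempfile.mkdtemp()) / "qwen3-0.6b"
(bundle / "components").mkdir(parents=True)
(bundle / "components" / "manifest.json").write_text("{}")

kinds = (resolve_target(str(bundle)), resolve_target("Qwen/Qwen3-0.6B"))
```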


Using the Python API Directly

You can also drive the transpiler from Python:

from cactus.transpile.capture_pytorch import capture_model
from cactus.transpile.lower import transpile_ir
from cactus.bindings.cactus import Graph
import torch
import numpy as np

# Define your model
class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(64, 32)

    def forward(self, x):
        return torch.relu(self.linear(x))

model = MyModel().eval()
example_input = torch.randn(1, 64)

# Capture and transpile
captured = capture_model(model, (example_input,))
transpiled = transpile_ir(captured.ir_graph)

# Run
transpiled.set_input(0, np.random.randn(1, 64).astype(np.float16))
outputs = transpiled.execute()
result = outputs[0].numpy()
print("Output shape:", result.shape)  # (1, 32)

Generic JAX User Graphs

JAX and Flax models can be bundled by providing params and one or more graph entrypoints. The caller owns tokenization, masks, sampling, and any prefill/decode loop; Cactus captures the supplied functions, writes FP16 mmap weights, and saves component graphs.

import jax.numpy as jnp
import numpy as np

from cactus.transpile.capture_jax import JaxGraphSpec
from cactus.transpile.jax_user_graph_bundle import (
    build_jax_user_graph_bundle,
    load_jax_user_graph_bundle,
)


def model(params, x):
    return jnp.maximum(x @ params["w"] + params["b"], 0)


params = {"w": jnp.ones((4, 3), jnp.float16), "b": jnp.zeros((3,), jnp.float16)}
example_x = jnp.ones((2, 4), jnp.float16)

result = build_jax_user_graph_bundle(
    params=params,
    output_dir="/tmp/my-jax-bundle",
    model_id="my-jax-model",
    specs=(JaxGraphSpec(name="forward", fn=model, example_args=(example_x,)),),
)

loaded = load_jax_user_graph_bundle(result.output_dir)
y = loaded.execute("forward", np.ones((2, 4), np.float16))[0].numpy()

For encoder-decoder models, provide each graph boundary explicitly:

params, model, tokenizer, config = load_needle_model()
src = tokenizer.encode("What time is it in Tokyo?")
tgt = tokenizer.encode("<bos>")
src_mask = make_padding_mask(src, config.pad_token_id)
tgt_mask = make_causal_mask(tgt.shape[1])
encoder_out, enc_mask = model.apply({"params": params}, src, src_mask, method=model.encode_text)

result = build_jax_user_graph_bundle(
    params=params,
    output_dir="weights/needle",
    model_id="needle",
    task="encoder-decoder",
    specs=(
        JaxGraphSpec(
            name="encoder",
            fn=lambda params, src, src_mask: model.apply(
                {"params": params}, src, src_mask, method=model.encode_text
            ),
            example_args=(src, src_mask),
            output_names=("encoder_out", "encoder_mask"),
        ),
        JaxGraphSpec(
            name="decoder_prefill",
            fn=lambda params, tgt, encoder_out, tgt_mask, cross_mask: model.apply(
                {"params": params}, tgt, encoder_out, tgt_mask, cross_mask, method=model.decode
            ),
            example_args=(tgt, encoder_out, tgt_mask, enc_mask),
            output_names=("logits",),
        ),
    ),
)

loaded = load_jax_user_graph_bundle(result.output_dir)
encoder_out, enc_mask = loaded.execute("encoder", src, src_mask)
logits = loaded.execute("decoder_prefill", tgt, encoder_out, tgt_mask, enc_mask)[0].numpy()
next_token = int(np.argmax(logits[0, -1]))

Each component is saved under components/<name>/ with graph.cactus, raw_ir.json, and optimized_ir.json. For decode, expose a separate decoder_step graph. The generic JAX path does not create or manage internal KV cache yet; component output Tensors can be passed directly into another component. Use loaded.reset() between independent generations.
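
Since the decode loop is left to the caller, a greedy loop might look like the sketch below. It is written against a generic step_fn callable so it runs without Cactus; in practice step_fn would wrap a call such as loaded.execute("decoder_step", ...), whose exact arguments depend on the graph you exported. greedy_decode and toy_step are hypothetical names.

```python
def greedy_decode(step_fn, prompt_ids, eos_id, max_new_tokens):
    """Caller-owned greedy loop around a per-step logits function."""
    tokens = list(prompt_ids)
    for _ in range(max_new_tokens):
        logits = step_fn(tokens)                                   # next-token logits
        next_id = max(range(len(logits)), key=logits.__getitem__)  # argmax
        tokens.append(next_id)
        if next_id == eos_id:
            break
    return tokens


def toy_step(tokens):
    """Stand-in for a decoder_step call: favors (last token + 1) mod 5."""
    vocab = 5
    target = (tokens[-1] + 1) % vocab
    return [1.0 if i == target else 0.0 for i in range(vocab)]


generated = greedy_decode(toy_step, prompt_ids=[1], eos_id=4, max_new_tokens=10)
```

The sketch passes the full token list on every step, which matches the no-internal-KV-cache situation described above; a caller that wants caching has to thread the cache tensors through the step graph explicitly.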


Artifact Layout

A transpiled bundle looks like this:

weights/<model>/
  raw_ir.json              # IR before optimization (debugging)
  optimized_ir.json        # IR after fusion passes (debugging)
  graph.cactus             # serialized runtime graph
  graph_bindings.json      # weight binding metadata
  result.json              # execution results (if --execute-after-transpile)
  components/              # for component-split models (e.g. multimodal)
    manifest.json          # component order, input/output names
    vision_encoder/
      graph.cactus
      bound_constants/
    audio_encoder/
      graph.cactus
    lm_encoder/
      graph.cactus
    decoder/
      graph.cactus

The manifest.json tells the runtime loader how to chain the components and where to find the memory-mapped weight files. The raw/optimized IR JSONs are also valid reload targets — the runtime can re-lower from saved IR as a fallback.
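
When debugging a bundle, it can help to list which component graphs are actually present on disk. The helper below is a hypothetical sketch that relies only on the directory layout shown above; it does not parse manifest.json, whose schema is not documented here.

```python
from pathlib import Path
import tempfile


def list_component_graphs(bundle_dir):
    """Map component name -> graph.cactus path for each components/<name>/."""
    components = Path(bundle_dir) / "components"
    return {
        child.name: child / "graph.cactus"
        for child in sorted(components.iterdir())
        if child.is_dir() and (child / "graph.cactus").is_file()
    }


# Build a throwaway directory that mimics the layout above.
bundle = Path(tempfile.mkdtemp()) / "gemma-4-e2b-it"
for name in ("vision_encoder", "audio_encoder", "lm_encoder", "decoder"):
    (bundle / "components" / name).mkdir(parents=True)
    (bundle / "components" / name / "graph.cactus").write_bytes(b"")
(bundle / "components" / "manifest.json").write_text("{}")

found = sorted(list_component_graphs(bundle))
```

The listing is alphabetical, not execution order; the order in which components are chained comes from manifest.json.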


CLI Reference

cactus transpile

cactus transpile <model-id-or-path> [options]

| Option | Description |
| --- | --- |
| --weights-dir <path> | Path to converted CQ weights (default: weights/<model_name>) |
| --task <name> | Force task type (default: auto, inferred from weights). Choices: causal_lm_logits, multimodal_causal_lm_logits, ctc_logits, encoder_hidden_states, seq2seq_transcription, tdt_transcription |
| --prompt <text> | Representative prompt for shape capture |
| --image-file <path> | Representative image (repeatable) |
| --audio-file <path> | Representative audio file |
| --max-new-tokens <N> | Generation room to preallocate for causal decode graphs |
| --artifact-dir <path> | Output directory (default: weights/<model>) |
| --execute-after-transpile | Run the graph after saving |
| --skip-reference-compare | Skip PyTorch vs transpiled comparison |
| --component-pipeline auto\|on\|off | Use split component graph transpilation when supported |
| --components <list> | Comma-separated component subset for component-pipeline models |
| --trust-remote-code | Allow HF remote code during the transpile phase |
| --local-files-only | Require HF files to already be local |
| --allow-unconverted-weights | Transpile against an unconverted source checkpoint |
| --no-fuse-rms-norm | Disable RMSNorm fusion |
| --no-fuse-rope | Disable RoPE fusion |
| --no-fuse-attention | Disable attention fusion |

NPU emission (CoreML .mlpackages for Apple Silicon audio and vision encoders) is available through the --npu, --npu-quantize, --npu-audio-quantize, and --npu-vision-quantize flags on cactus transpile.

cactus run

cactus run <model-id-or-bundle-path> [options]

The positional argument may be a HuggingFace model id (downloads the matching bundle from huggingface.co/Cactus-Compute) or a local path to a bundle directory.

| Option | Description |
| --- | --- |
| --bits 1\|2\|3\|4 | CQ quantization bits when downloading (default: 4) |
| --platform cpu\|apple | Target accelerator (default: cpu, generic ARM) |
| --token <token> | HuggingFace token (gated models) |
| --prompt <text> | Input prompt |
| --input-ids <ids> | Comma-separated token IDs |
| --image <path> | Image file for multimodal bundles |
| --audio <path> | Audio file for speech / audio bundles |
| --system <prompt> | System prompt |
| --thinking | Enable thinking / reasoning mode |
| --max-new-tokens <N> | Max tokens to generate |
| --result-json <path> | Save bundle result payload to JSON |
| --no-cloud-handoff | Disable automatic cloud handoff |
| --confidence-threshold <f> | Confidence threshold below which to hand off to cloud |
| --cloud-timeout-ms <N> | Max wait time for cloud handoff |
| --no-cloud-tele | Disable cloud telemetry (write to cache only) |
| --reconvert | Force local convert+transpile fallback |

Supported Model Families

The transpiler works with any model that torch.export can capture, but has tested adapters for:

| Family | Tasks | Component Split |
| --- | --- | --- |
| Gemma 3/4 | text, multimodal | yes (vision + audio + LM + decoder) |
| Qwen 3/3.5 | text, vision | no |
| LFM2/2.5 | text, vision | no |
| Whisper | encoder | no |
| Parakeet CTC | encoder | no |
| Parakeet TDT | transcription | yes (encoder + decoder) |

Adding a new model family means writing a small adapter in model_adapters.py that wraps the HF forward into the canonical task interface. The rest of the pipeline (capture, optimize, lower) is model-agnostic.


Design Decisions

ATen ops, not layer names. After torch.export, the transpiler dispatches on canonical ATen operation names (aten.linear.default -> linear). It never reads HuggingFace module paths like model.layers.0.self_attn. This means the fusion passes work on any model that produces the same computational pattern, regardless of how the code is organized.
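
The name mapping mentioned above (aten.linear.default -> linear) can be sketched as a small string transform. This is an assumption about the general shape of the mapping, not the transpiler's actual dispatch table, and canonical_op_name is a hypothetical helper.

```python
def canonical_op_name(aten_target: str) -> str:
    """'aten.linear.default' -> 'linear' (namespace and overload stripped)."""
    parts = aten_target.split(".")
    if parts[0] != "aten" or len(parts) < 2:
        raise ValueError(f"not an ATen target: {aten_target!r}")
    return parts[1]


names = [canonical_op_name(t) for t in ("aten.linear.default", "aten.silu.default")]
```

Because the key is the operation name rather than a module path, two models that organize their code differently but emit the same ATen ops end up with the same canonical names.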

Weight binding through manifests. The transpiler resolves weights exclusively through weights_manifest.json written by cactus convert. It does not guess filenames from layer names. This is intentional — it keeps the contract between conversion and transpilation explicit and auditable.

Component splitting for multimodal. Complex models like Gemma4 are split into independent components (vision encoder, audio encoder, LM encoder, decoder) before capture. Each component goes through the full optimize/lower pipeline independently. This keeps each graph small enough for torch.export to handle and lets the runtime load/unload components independently.

Saved IR as fallback. Bundles include both the serialized .cactus graph and the raw/optimized IR JSON. If the binary graph format changes between versions, the runtime can re-lower from saved IR instead of requiring a full re-transpile.

See Also