Section 17.4: MJX/Brax Training and JAX RL

A Careful Control Loop
Figure 17.4A: MJX and Brax reward the builder who thinks in batched arrays first: fixed shapes, explicit random keys, and rollout buffers sized before the first compile.
Big Picture

MJX, Brax, and JAX RL move the entire reinforcement-learning loop toward accelerator-native arrays. The simulator step, policy inference, advantage computation, and update can all be expressed as compiled transformations over batched state.

For MJX, Brax, and JAX RL, a GPU training result is only interpretable when simulator fidelity settings, PPO rollout semantics, reward terms, and the reset distribution are versioned together in the same training artifact.

This section develops the JAX-native training contract for MJX and Brax. We focus on vectorized environments, jit boundaries, vmap or batched stepping, lax.scan rollout collection, and explicit PRNG splitting.

The key question is practical: can the rollout loop be compiled once, run many times with fixed tensor shapes, and still produce independent training and evaluation episodes?

Compile Once, Vary By Data

JAX rewards programs whose control flow and tensor shapes stay stable. Vary terrain, commands, and seeds through arrays, not through Python branches that force recompilation.
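The sketch below shows the difference using a hypothetical terrain-penalty function. It is not an MJX or Brax API. The Python counter increments only while JAX traces the function, so it counts compilations rather than calls. New values in the same shapes reuse the compiled program, while a new batch shape forces a retrace.

```python
import jax
import jax.numpy as jnp

# The Python body of a jitted function runs only while JAX traces it,
# so this counter counts traces (compilations), not calls.
TRACE_COUNT = {"n": 0}

@jax.jit
def terrain_penalty(heights, difficulty):
    """Hypothetical per-environment penalty; not an MJX or Brax function."""
    TRACE_COUNT["n"] += 1
    # Data-dependent choice as an array op instead of a Python `if`:
    # new difficulty values never change the compiled program.
    rough = jnp.abs(heights) * difficulty[:, None]
    return jnp.where(difficulty[:, None] > 0.5, rough, 0.0).sum(axis=-1)

heights = jnp.ones((4, 8))  # (num_envs, terrain samples)
terrain_penalty(heights, jnp.array([0.1, 0.2, 0.3, 0.4]))
terrain_penalty(heights, jnp.array([0.9, 0.8, 0.7, 0.6]))  # same shapes: cache hit
print("traces after value change:", TRACE_COUNT["n"])  # 1

terrain_penalty(jnp.ones((8, 8)), jnp.ones(8))  # new batch shape: retrace
print("traces after shape change:", TRACE_COUNT["n"])  # 2
```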

Theory

MJX provides MuJoCo-style physics through JAX arrays, while Brax provides a JAX-native physics and RL stack designed for massively parallel simulation. In both cases, the important unit is not an object-oriented environment instance; it is a batched simulator state transformed by pure functions.

A typical rollout uses a compiled function that maps policy parameters, simulator state, and random keys to new state and experience tensors. The builder must know which dimensions are static, which arrays are donated or reused, and how many bytes the rollout buffer consumes before training begins.
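`jax.eval_shape` can answer the shape and byte questions before any device memory is allocated. It traces a function abstractly and returns only the output shapes and dtypes. The rollout below is a hypothetical stand-in with toy dynamics, sized like Code Fragment 17.4.1. A real MJX or Brax step function would take its place.

```python
import math
import jax
import jax.numpy as jnp

NUM_ENVS, HORIZON, OBS_DIM, ACT_DIM = 8192, 32, 64, 12

def collect_rollout(params, obs0, key):
    """Hypothetical rollout with toy dynamics; only its output shapes matter here."""
    def step(obs, k):
        noise = jax.random.normal(k, (NUM_ENVS, ACT_DIM))
        action = jnp.tanh(obs @ params + noise)
        next_obs = obs + 0.01 * jnp.sum(action, axis=-1, keepdims=True)
        reward = -jnp.sum(action ** 2, axis=-1)
        return next_obs, {"observations": obs, "actions": action, "rewards": reward}

    _, traj = jax.lax.scan(step, obs0, jax.random.split(key, HORIZON))
    return traj

# ShapeDtypeStruct inputs: eval_shape traces abstractly and allocates no buffers.
params_spec = jax.ShapeDtypeStruct((OBS_DIM, ACT_DIM), jnp.float32)
obs_spec = jax.ShapeDtypeStruct((NUM_ENVS, OBS_DIM), jnp.float32)
out = jax.eval_shape(collect_rollout, params_spec, obs_spec, jax.random.PRNGKey(0))

total = 0
for name, spec in out.items():
    nbytes = math.prod(spec.shape) * spec.dtype.itemsize
    total += nbytes
    print(f"{name}: shape={spec.shape} dtype={spec.dtype} {nbytes / 1e6:.1f} MB")
print(f"total: {total / 1e6:.1f} MB")
```

The printed shapes are the static part of the compile contract. The byte counts should agree with a hand estimate before the function is ever run for real.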

Mechanism

The mechanism is functional batching: split a PRNG key into per-environment keys, step a batch of simulator states, write observations and rewards into a fixed-shape buffer, and use compiled array operations for the update. Debugging starts by printing shapes, dtypes, memory footprint, and key counts.
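The sketch below runs the mechanism end to end. A toy point-mass environment stands in for MJX or Brax, and the dynamics, policy, and sizes are all hypothetical. `jax.vmap` batches a single-environment step, and `jax.random.split` gives each environment its own key at every step. `jax.lax.scan` writes the rollout into time-major, fixed-shape buffers inside one compiled function.

```python
import jax
import jax.numpy as jnp

NUM_ENVS, HORIZON, OBS_DIM, ACT_DIM = 16, 8, 4, 2

def env_step(state, action, key):
    """Hypothetical single-environment step: a noisy point mass, not MJX or Brax."""
    noise = 0.01 * jax.random.normal(key, state.shape)
    next_state = state + 0.1 * jnp.pad(action, (0, OBS_DIM - ACT_DIM)) + noise
    reward = -jnp.sum(next_state ** 2)
    return next_state, reward

batched_step = jax.vmap(env_step)  # maps over axis 0 of state, action, and key

def policy(params, obs):
    return jnp.tanh(obs @ params)

@jax.jit
def rollout(params, state, key):
    def scan_step(carry, _):
        state, key = carry
        key, step_key = jax.random.split(key)
        env_keys = jax.random.split(step_key, NUM_ENVS)  # one key per environment
        action = policy(params, state)
        next_state, reward = batched_step(state, action, env_keys)
        return (next_state, key), (state, action, reward)

    (final_state, _), (obs, actions, rewards) = jax.lax.scan(
        scan_step, (state, key), xs=None, length=HORIZON)
    return final_state, obs, actions, rewards

params = 0.1 * jnp.ones((OBS_DIM, ACT_DIM))
state0 = jnp.zeros((NUM_ENVS, OBS_DIM))
final_state, obs, actions, rewards = rollout(params, state0, jax.random.PRNGKey(0))
for name, x in [("obs", obs), ("actions", actions), ("rewards", rewards)]:
    print(name, x.shape, x.dtype)  # time-major: (HORIZON, NUM_ENVS, ...)
```

All environments start from the same zero state here. They still diverge because each one draws from its own key.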

Worked Example

Code Fragment 17.4.1 estimates the memory footprint of the rollout buffer before any JAX training code is written. This is the first sanity check for whether a planned MJX or Brax experiment fits on the accelerator.

# Estimate device memory for a JAX-native rollout buffer.
# Static shapes should be chosen before jit compilation and training launch.
num_envs = 8192
horizon = 32
obs_dim = 64
act_dim = 12
float_bytes = 4

buffers = {
    "observations": horizon * num_envs * obs_dim * float_bytes,
    "actions": horizon * num_envs * act_dim * float_bytes,
    "rewards": horizon * num_envs * float_bytes,
    "dones": horizon * num_envs * float_bytes,
    "values": horizon * num_envs * float_bytes,
}

total_mb = sum(buffers.values()) / 1_000_000
for name, size in buffers.items():
    print(f"{name}: {size / 1_000_000:.1f} MB")
print(f"rollout buffer total: {total_mb:.1f} MB")
observations: 67.1 MB
actions: 12.6 MB
rewards: 1.0 MB
dones: 1.0 MB
values: 1.0 MB
rollout buffer total: 82.8 MB
Code Fragment 17.4.1 sizes a fixed-shape rollout buffer for an 8,192-environment JAX run. The observation tensor dominates memory, so adding history, pixels, or privileged critic state should be treated as a memory decision before compilation.

Expected output: the printout should show which tensor dominates memory. A JAX run that recompiles, or that runs out of device memory because shapes were chosen late, loses the throughput advantage this section is trying to teach.
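The same arithmetic fits in a small helper that prices memory decisions such as observation history and privileged critic state. This is a hypothetical sketch. It assumes each stored step keeps a full stacked observation history and a 128-dimensional privileged critic vector, which are illustrative choices rather than recommendations. Many implementations reconstruct history from the buffer instead of storing it.

```python
def rollout_bytes(num_envs, horizon, obs_dim, act_dim, *, history=1,
                  critic_extra_dim=0, float_bytes=4):
    """Hypothetical helper: bytes for one fixed-shape rollout buffer.

    Stores, for every (step, env): a stacked observation history, optional
    privileged critic inputs, actions, and three scalars (reward, done, value).
    """
    steps = horizon * num_envs
    obs = steps * obs_dim * history * float_bytes
    critic = steps * critic_extra_dim * float_bytes
    actions = steps * act_dim * float_bytes
    scalars = 3 * steps * float_bytes  # rewards, dones, values
    return obs + critic + actions + scalars

base = rollout_bytes(8192, 32, 64, 12)
stacked = rollout_bytes(8192, 32, 64, 12, history=4, critic_extra_dim=128)
print(f"baseline: {base / 1e6:.1f} MB")  # 82.8 MB
print(f"4-frame history + 128-dim critic state: {stacked / 1e6:.1f} MB")  # 418.4 MB
```

Under these assumptions, the two additions together multiply the buffer by roughly five. This kind of change should be decided before the first compile, not discovered afterward.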

Library Shortcut

Brax and MJX already provide the simulator side of the accelerator-native shortcut. The builder's job is to keep the surrounding RL code compatible with that shortcut: no per-environment Python loops, no shape changes inside the hot path, and no hidden transfer from device to host during rollout collection.
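The first rule, no per-environment Python loops, can be checked directly. A batched step produced by `jax.vmap` should match a Python loop numerically while compiling to a single array program. The step function here is a hypothetical placeholder.

```python
import jax
import jax.numpy as jnp

def step(state, action):
    """Hypothetical single-environment dynamics (not an MJX or Brax API)."""
    return state + 0.1 * jnp.tanh(action)

states = jnp.arange(12.0).reshape(4, 3)  # (num_envs, state_dim)
actions = jnp.ones((4, 3))

# Anti-pattern: one small computation per environment; under jit this loop
# would unroll into a graph that grows with the number of environments.
looped = jnp.stack([step(s, a) for s, a in zip(states, actions)])

# Accelerator-friendly: one compiled, batched call over the environment axis.
batched = jax.jit(jax.vmap(step))(states, actions)
assert jnp.allclose(looped, batched)
```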

Practical Recipe

  1. Choose static shapes for environment count, horizon, observation groups, action dimension, and privileged critic state.
  2. Split PRNG keys per environment and per rollout step so that no two environments, and no two steps, reuse the same random draws.
  3. Put reset logic into compiled array operations where possible.
  4. Profile compile time separately from steady-state training throughput.
  5. Evaluate on a separate key stream and a separate task panel, then save both key roots.
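Steps 2 and 5 can be expressed as a deterministic key schedule. This sketch assumes integer key roots like those in the run manifest later in this section. It uses `jax.random.fold_in` to derive a per-step key and `jax.random.split` to derive per-environment keys. That pairing is one reasonable convention, not a requirement of MJX or Brax.

```python
import jax
import jax.numpy as jnp

NUM_ENVS = 4
TRAIN_KEY_ROOT, EVAL_KEY_ROOT = 1704001, 1704999  # integer roots, as in a run manifest

train_root = jax.random.PRNGKey(TRAIN_KEY_ROOT)
eval_root = jax.random.PRNGKey(EVAL_KEY_ROOT)

def env_keys_for_step(root, step, num_envs):
    """Hypothetical schedule: fold the step index into the root, then split per env."""
    step_key = jax.random.fold_in(root, step)  # distinct key per rollout step
    return jax.random.split(step_key, num_envs)  # distinct key per environment

train_step0 = env_keys_for_step(train_root, 0, NUM_ENVS)
train_step1 = env_keys_for_step(train_root, 1, NUM_ENVS)
eval_step0 = env_keys_for_step(eval_root, 0, NUM_ENVS)

# Per-environment reset perturbations drawn from the step-0 training keys.
reset_noise = jax.vmap(
    lambda k: jax.random.uniform(k, (2,), minval=-0.1, maxval=0.1))(train_step0)

# Same root and step reproduce the same keys; different roots or steps do not.
assert jnp.array_equal(train_step0, env_keys_for_step(train_root, 0, NUM_ENVS))
assert not jnp.array_equal(train_step0, train_step1)
assert not jnp.array_equal(train_step0, eval_step0)
```

Saving the two integer roots is enough to regenerate every key in both streams, which is what makes the evaluation panel auditable.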
Common Failure Mode

The most common mistake is to time the first call, which includes compilation, in one result but not in another. Another is to change the batch shape during curriculum updates, which silently triggers recompilation and confuses wall-clock comparisons.
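The timing sketch below keeps the two numbers apart. It uses JAX's ahead-of-time path, `jax.jit(...).lower(...).compile()`, so compilation happens in its own timed block. The update function and sizes are hypothetical. `block_until_ready` is needed because JAX dispatches work asynchronously, and a timer without it can stop before the device finishes.

```python
import time
import jax
import jax.numpy as jnp

def update(params, batch):
    """Hypothetical update: one gradient step on a least-squares loss."""
    def loss(p):
        return jnp.mean((batch @ p) ** 2)
    return params - 1e-3 * jax.grad(loss)(params)

params = jnp.ones((64, 12))
batch = jnp.ones((4096, 64))

# Field 1: compile time, measured by lowering and compiling explicitly.
t0 = time.perf_counter()
compiled = jax.jit(update).lower(params, batch).compile()
compile_s = time.perf_counter() - t0

# Field 2: steady-state time per update; the compiled executable never retraces.
compiled(params, batch).block_until_ready()  # warm-up call
n_calls = 20
t0 = time.perf_counter()
for _ in range(n_calls):
    params = compiled(params, batch)
params.block_until_ready()  # wait for asynchronous device work before stopping
steady_s = (time.perf_counter() - t0) / n_calls

print(f"compile: {compile_s:.3f} s, steady-state: {steady_s * 1e3:.3f} ms/update")
```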

Practical Example

A team building an MJX quadruped task should freeze the model topology, observation shape, and horizon during a benchmark. Terrain difficulty can change through arrays and masks, while the compiled rollout shape remains constant.
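The sketch below is one way to express that curriculum. It assumes terrain is organized into a hypothetical fixed maximum of 10 difficulty levels. The set of unlocked levels is a fixed-length boolean mask, locked levels receive `-inf` logits, and `jax.random.categorical` samples one level per environment. Unlocking more levels changes mask values, not shapes, so the jitted sampler compiles once.

```python
import jax
import jax.numpy as jnp

NUM_ENVS, MAX_LEVELS = 1024, 10  # hypothetical sizes, fixed before compilation

@jax.jit
def sample_levels(key, unlocked):
    """Sample one terrain level per env; `unlocked` is a bool[MAX_LEVELS] mask."""
    # Locked levels get -inf logits, so categorical sampling never picks them.
    logits = jnp.where(unlocked, 0.0, -jnp.inf)
    logits = jnp.broadcast_to(logits, (NUM_ENVS, MAX_LEVELS))
    return jax.random.categorical(key, logits, axis=-1)

key = jax.random.PRNGKey(0)
levels_by_stage = {}
for num_unlocked in (2, 5):  # curriculum stages change mask values, not shapes
    key, subkey = jax.random.split(key)
    unlocked = jnp.arange(MAX_LEVELS) < num_unlocked
    levels_by_stage[num_unlocked] = sample_levels(subkey, unlocked)
    print(num_unlocked, levels_by_stage[num_unlocked].shape)
```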

Fun Note

JAX is happiest when the experiment arrives wearing the same tensor shape every day. Surprise it with a new shape mid-run, and the compiler gets a vote.

Research Frontier

MJX and Brax are part of a broader move toward robot-learning stacks where physics, policy inference, and learning live inside one accelerator program. The frontier question is how much task richness, especially contact complexity and pixel observations, can be added before compilation, memory, or transfer costs dominate the learning loop.

Self Check

Can you name the static rollout shape, memory footprint, PRNG key schedule, compile boundary, and held-out evaluation key stream? If not, the MJX or Brax run is not yet reproducible.

The idea in this section becomes useful when accelerator constraints are treated as part of the algorithm. A JAX RL loop is fast because it exposes the whole rollout as a regular computation graph. That same regularity means dynamic episode bookkeeping, variable-size observations, and host-side logging must be designed carefully.
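Episode bookkeeping shows how to design around that regularity. The sketch below keeps fixed-shape accumulators instead of appending finished episodes to a Python list, and applies auto-reset with `jnp.where` inside `lax.scan`. The reward and done arrays are made-up inputs chosen so the result can be checked by hand.

```python
import jax
import jax.numpy as jnp

def update_episode_stats(stats, reward, done):
    """Fixed-shape bookkeeping: per-env running returns plus scalar totals."""
    running = stats["running_return"] + reward
    completed_sum = stats["completed_sum"] + jnp.sum(jnp.where(done, running, 0.0))
    completed_count = stats["completed_count"] + jnp.sum(done)
    running = jnp.where(done, 0.0, running)  # auto-reset finished environments
    return {"running_return": running,
            "completed_sum": completed_sum,
            "completed_count": completed_count}

def scan_body(stats, xs):
    reward, done = xs
    return update_episode_stats(stats, reward, done), None

rewards = jnp.ones((4, 3))  # (horizon, num_envs), made-up data
dones = jnp.array([[0, 0, 0], [1, 0, 0], [0, 0, 1], [1, 0, 0]], dtype=bool)
init = {"running_return": jnp.zeros(3),
        "completed_sum": jnp.zeros(()),
        "completed_count": jnp.zeros((), jnp.int32)}

stats, _ = jax.jit(lambda s, r, d: jax.lax.scan(scan_body, s, (r, d)))(init, rewards, dones)
mean_return = stats["completed_sum"] / jnp.maximum(stats["completed_count"], 1)
print(int(stats["completed_count"]), float(mean_return))
```

Three episodes finish here, with returns 2, 2, and 3, so the mean completed return is 7/3. Unfinished returns stay in the carry for the next rollout, and only the scalar totals need to reach host-side logging.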

The graduate-level habit is to report compile time, steady-state steps per second, memory footprint, shape choices, and key schedule separately. Otherwise a result may confuse a better algorithm with a better compilation path.

Practical Tool Choices For This Section
| Tool or Library | Role in the Topic | Builder Advice |
|---|---|---|
| MJX | MuJoCo-style models through JAX arrays | Use it when you want MuJoCo modeling concepts with accelerator-friendly batched stepping. |
| Brax | JAX-native batched physics and RL environments | Use it when fast parallel simulation and functional training loops are the main goal. |
| jax.jit | Compile rollout and update functions | Use it around stable hot paths, and keep shape-changing logic outside benchmarks. |
| jax.vmap or batched state | Apply one step function across many environments | Use it to express environment parallelism as array structure rather than Python loops. |
| jax.random.split | Independent randomness for environments and evaluation | Use it to record train and evaluation key streams without accidental reuse. |

A robust implementation starts with a JAX run manifest. It records the shapes that must stay static, the keys that drive randomness, and the metrics that separate compilation overhead from steady-state throughput.

  1. Freeze batch, horizon, observation, action, and privileged-state shapes before the first benchmark.
  2. Record root PRNG keys for training, curriculum randomization, and evaluation separately.
  3. Measure first-call compile time and post-compile steady-state throughput as different fields.
  4. Save memory estimates for rollout buffers and optimizer state.
  5. Compare algorithms only when they share the same static shapes and held-out key stream.
# Record the JAX-specific fields that make an MJX or Brax run auditable.
# Static shapes and separate PRNG roots prevent hidden recompilation and leakage.
from dataclasses import dataclass, asdict

@dataclass
class JaxRlManifest:
    simulator: str
    batch_shape: tuple[int, int]
    obs_dim: int
    action_dim: int
    train_key_root: int
    eval_key_root: int

    def as_row(self) -> dict[str, object]:
        return asdict(self)

manifest = JaxRlManifest(
    simulator="MJX",
    batch_shape=(32, 8192),
    obs_dim=64,
    action_dim=12,
    train_key_root=1704001,
    eval_key_root=1704999,
)
print(manifest.as_row())
{'simulator': 'MJX', 'batch_shape': (32, 8192), 'obs_dim': 64, 'action_dim': 12, 'train_key_root': 1704001, 'eval_key_root': 1704999}
Code Fragment 17.4.2 records the JAX fields that determine whether an MJX or Brax experiment is reproducible. The batch shape fixes the rollout compilation contract, while separate key roots prevent evaluation from reusing training randomness.
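Recipe step 5 can be enforced mechanically before any metrics are compared. This sketch redefines the manifest as a frozen dataclass so the block runs on its own. Which fields belong to the comparison contract is an assumption here: static shapes, simulator, and evaluation key root must match, while training key roots may differ across seeds.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class JaxRlManifest:
    simulator: str
    batch_shape: tuple[int, int]
    obs_dim: int
    action_dim: int
    train_key_root: int
    eval_key_root: int

# Hypothetical contract: fields that must match before results are compared.
CONTRACT_FIELDS = ("simulator", "batch_shape", "obs_dim", "action_dim", "eval_key_root")

def contract_mismatches(a, b):
    """Return the contract fields on which two run manifests disagree."""
    return [f for f in CONTRACT_FIELDS if getattr(a, f) != getattr(b, f)]

baseline = JaxRlManifest("MJX", (32, 8192), 64, 12, 1704001, 1704999)
candidate = JaxRlManifest("MJX", (32, 4096), 64, 12, 1704002, 1704999)
print(contract_mismatches(baseline, candidate))  # ['batch_shape']
```

A non-empty list means the two runs were compiled against different shapes or evaluated on different key panels, so their numbers should not share a table.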

When a JAX-native run fails, separate numerical, compilation, and RL causes. Check NaNs, dtype choices, action clipping, reset masks, key reuse, recompilation counts, and host-device transfers before blaming the policy architecture.
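Several of those checks reduce to cheap array reductions that can run on every rollout batch without leaving the device. The function below is a hypothetical sketch with made-up inputs. For debugging, JAX also provides configuration flags such as `jax_debug_nans` and `jax_log_compiles` for catching NaNs and logging compilations.

```python
import jax
import jax.numpy as jnp

@jax.jit
def batch_health(obs, raw_actions, dones, action_limit=1.0):
    """Hypothetical fixed-shape diagnostics to check before blaming the policy."""
    return {
        # Fraction of environments whose observation contains any NaN.
        "nan_obs_fraction": jnp.mean(jnp.isnan(obs).any(axis=-1)),
        # Fraction of action entries that would be clipped at the limit.
        "clipped_action_fraction": jnp.mean(jnp.abs(raw_actions) > action_limit),
        # Fraction of environments resetting this step (a sanity check on masks).
        "reset_fraction": jnp.mean(dones),
    }

obs = jnp.array([[0.0, 1.0], [jnp.nan, 0.0], [1.0, 2.0], [0.0, 0.0]])
raw_actions = jnp.array([[0.5], [1.5], [-2.0], [0.9]])
dones = jnp.array([False, True, False, False])
health = batch_health(obs, raw_actions, dones)
print({k: float(v) for k, v in health.items()})
```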

Evaluation Recipe

For MJX and Brax experiments, compare only metrics computed under one matched configuration: the same batch shape, simulator model, policy checkpoint, PRNG key panel, perturbation suite, and success definition. Save compile time, steady-state steps per second, memory footprint, return, success, and failure labels together as one artifact.
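The sketch below covers the key-panel part of this recipe. The evaluation panel is derived only from the evaluation root, and every checkpoint is scored against the identical panel in a compiled, vmapped pass. The one-step episode, the checkpoints, and the success threshold are hypothetical placeholders for a real MJX or Brax evaluation rollout.

```python
import jax
import jax.numpy as jnp

PANEL_SIZE, OBS_DIM, ACT_DIM = 64, 4, 2
EVAL_KEY_ROOT = 1704999
SUCCESS_THRESHOLD = -0.1  # hypothetical success definition

# One fixed panel of evaluation keys, derived only from the eval root.
eval_panel = jax.random.split(jax.random.PRNGKey(EVAL_KEY_ROOT), PANEL_SIZE)

def episode_return(params, key):
    """Hypothetical one-step episode: random start state, score the action."""
    obs = jax.random.normal(key, (OBS_DIM,))
    action = jnp.tanh(obs @ params)
    return -jnp.sum((action - 0.5) ** 2)

@jax.jit
def evaluate(params, panel):
    returns = jax.vmap(episode_return, in_axes=(None, 0))(params, panel)
    return {"mean_return": returns.mean(),
            "success_rate": jnp.mean(returns > SUCCESS_THRESHOLD)}

checkpoints = {"a": jnp.zeros((OBS_DIM, ACT_DIM)),
               "b": 0.5 * jnp.ones((OBS_DIM, ACT_DIM))}
report = {name: evaluate(p, eval_panel) for name, p in checkpoints.items()}
print({name: {k: float(v) for k, v in r.items()} for name, r in report.items()})
```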

Key Takeaway

MJX and Brax make GPU RL effective when the experiment is expressed as stable batched array transformations. Static shapes, explicit PRNG keys, and device-resident buffers are part of the scientific method, not just implementation details.

Exercise 17.4.1

Design an MJX or Brax benchmark with 8,192 environments. Specify horizon, observation dimension, action dimension, rollout memory, train key root, evaluation key root, and how you would report compile time separately from steady-state throughput.

What's Next?

This section turned MJX and Brax training into a JAX-native contract: fixed shapes, explicit random keys, memory budgeting, compile boundaries, and held-out evaluation. Next, continue with Section 17.5, where privileged-information teachers use these fast simulators to train deployable students.

References & Further Reading
Foundational Papers, Tools, and Practice References

Makoviychuk, V. et al. (2021). Isaac Gym: High Performance GPU-Based Physics Simulation For Robot Learning. arXiv.

Isaac Gym is useful here as the GPU-resident predecessor to newer accelerator-native workflows. It frames why MJX and Brax also care about keeping simulation and learning close to the device.

Paper

Freeman, C. D. et al. (2021). Brax: A Differentiable Physics Engine for Large Scale Rigid Body Simulation. arXiv.

Brax is central to this section because it treats physics and RL as JAX-friendly batched computation. Use it to study functional environment stepping, vectorization, and accelerator-scale rollout design.

Paper

NVIDIA Isaac Lab documentation.

Isaac Lab offers a useful contrast to MJX and Brax. It emphasizes task and runner integration around Isaac Sim, while this section emphasizes compiled JAX loops and static rollout shapes.

Tool

Google DeepMind MuJoCo MJX documentation.

The MJX documentation is the primary source for MuJoCo physics inside JAX. Consult it for API boundaries, supported model features, and how batched data differs from classic MuJoCo usage.

Tool

Rudin, N. et al. (2022). Learning to Walk in Minutes Using Massively Parallel Deep Reinforcement Learning. CoRL.

Rudin et al. provide the locomotion workload that many JAX-native efforts try to accelerate or reproduce. Use it as a behavioral target, not as evidence that any new simulator setup transfers automatically.

Paper

RSL-RL repository.

RSL-RL is included as a non-JAX point of comparison for PPO storage and locomotion conventions. It helps readers distinguish algorithm settings from simulator execution model.

Tool