Use the JAX backend

This guide explains how to use the JAX backend in Meridian.

Introduction to the JAX backend

By default, Meridian uses TensorFlow for its core numerical operations and probabilistic Markov Chain Monte Carlo (MCMC) sampling, providing a robust and thoroughly tested foundation for all modeling tasks.

For projects that benefit from faster runtimes and lower memory use, Meridian also provides an optional JAX backend. JAX encourages a functional programming style and uses XLA (Accelerated Linear Algebra) compilation to optimize performance.

Tutorial: For a hands-on walkthrough of the JAX backend, see the Getting started with JAX notebook.

How to enable JAX

Because Meridian loads its core mathematical libraries when it is first imported, you must tell Meridian to use JAX before importing any Meridian modules.

To enable JAX, set the MERIDIAN_BACKEND environment variable to 'jax'.

You must set this environment variable in your script before executing any import meridian statements:

import os

# Enable JAX before importing Meridian
os.environ['MERIDIAN_BACKEND'] = 'jax'

# Now it is safe to import Meridian modules
from meridian.model import model
from meridian.data import load
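In notebooks, it is easy to set the variable after Meridian has already been imported in the same session, in which case the setting arrives too late. The following sketch wraps this step in a small helper that refuses to run if any `meridian` module is already loaded. The helper name `configure_meridian_backend` is hypothetical and is not part of Meridian; it only sets the MERIDIAN_BACKEND environment variable described above.

```python
import os
import sys


def configure_meridian_backend(backend: str = "jax") -> None:
    """Set MERIDIAN_BACKEND, refusing if Meridian is already imported.

    Hypothetical helper (not part of Meridian). The backend is chosen when
    Meridian is first imported, so changing the variable afterwards in the
    same process would not take effect.
    """
    already_loaded = [
        name for name in sys.modules
        if name == "meridian" or name.startswith("meridian.")
    ]
    if already_loaded:
        raise RuntimeError(
            "Meridian is already imported in this process; restart the "
            "kernel or interpreter before changing MERIDIAN_BACKEND."
        )
    os.environ["MERIDIAN_BACKEND"] = backend


configure_meridian_backend("jax")
# Import Meridian only after the backend has been configured:
# from meridian.model import model
```

The check only inspects the current Python process, so it cannot detect a different process that imported Meridian with other settings.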

Enable 64-bit precision

For models where convergence is difficult to achieve, 64-bit precision with the JAX backend can improve numerical stability. The trade-off is higher memory usage and slower computation, so 32-bit precision remains the default for most use cases. To enable 64-bit precision, set the MERIDIAN_ENABLE_JAX_X64 environment variable to 'True' before importing Meridian.

import os

# Enable JAX 64-bit precision
os.environ['MERIDIAN_ENABLE_JAX_X64'] = 'True'

# Enable JAX backend
os.environ['MERIDIAN_BACKEND'] = 'jax'

# Now it is safe to import Meridian modules
from meridian.model import model
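The memory cost follows from storage size: a 64-bit float takes 8 bytes versus 4 bytes for a 32-bit float, so each floating-point array the model holds roughly doubles in size. The sketch below uses Python's standard `array` module only to look up those item sizes. The 10-million-value array is an arbitrary illustrative size, not a measured Meridian workload, and total memory use also depends on factors this arithmetic ignores.

```python
from array import array


def float_array_megabytes(n_values: int, use_x64: bool) -> float:
    """Return the storage, in MB, of n_values floats at 32- or 64-bit precision."""
    # On mainstream platforms, array('d') stores C doubles (8 bytes) and
    # array('f') stores C floats (4 bytes).
    itemsize = array("d").itemsize if use_x64 else array("f").itemsize
    return n_values * itemsize / 1e6


n_values = 10_000_000  # Arbitrary illustrative size.
print(float_array_megabytes(n_values, use_x64=False))  # 40.0
print(float_array_megabytes(n_values, use_x64=True))   # 80.0
```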

If MERIDIAN_BACKEND is set to an unrecognized value, Meridian issues a RuntimeWarning and falls back to the default TensorFlow backend. If MERIDIAN_ENABLE_JAX_X64 is set to any value other than 'True' or '1', 64-bit precision is not enabled and Meridian uses 32-bit precision.
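These rules can be summarized as a small function that reports which configuration a given environment would produce, which is handy for checking settings before the import. This is a hypothetical sketch of the documented behavior, not Meridian's own implementation. It assumes that 'tensorflow' and 'jax' are the only accepted backend values, that matching is exact (case-sensitive), and that an unset MERIDIAN_BACKEND means TensorFlow.

```python
import os
import warnings

# Assumption: accepted backend names; check Meridian's documentation.
_KNOWN_BACKENDS = ("tensorflow", "jax")


def effective_meridian_config(environ=os.environ):
    """Return (backend, x64_enabled) following the rules in this guide.

    Hypothetical restatement of the documented rules, not Meridian code.
    """
    backend = environ.get("MERIDIAN_BACKEND", "tensorflow")
    if backend not in _KNOWN_BACKENDS:
        warnings.warn(
            f"Unknown MERIDIAN_BACKEND {backend!r}; falling back to TensorFlow.",
            RuntimeWarning,
        )
        backend = "tensorflow"
    # Only 'True' or '1' enable 64-bit precision (meaningful with JAX only).
    x64 = environ.get("MERIDIAN_ENABLE_JAX_X64") in ("True", "1")
    return backend, x64


print(effective_meridian_config({"MERIDIAN_BACKEND": "jax",
                                 "MERIDIAN_ENABLE_JAX_X64": "yes"}))
# ('jax', False): 'yes' is not one of the enabling values.
```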

Numerical differences and reproducibility

Because TensorFlow and JAX compile their computational graphs differently, you may observe minor numerical differences in your posterior estimates after switching to JAX, even with the same data and random seeds.

While posterior distributions might not be identical across backends, the differences are generally small and not statistically significant for business metrics such as ROI and budget allocation. In practice, switching to the JAX backend should not change the conclusions you draw from your model.
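To check this for your own model, one approach is to fit it once under each backend (in separate Python processes, because the backend is fixed at import time), collect a per-channel posterior summary such as the mean and standard deviation of ROI, and compare the two. The sketch below scales each difference in posterior means by the TensorFlow posterior standard deviation, so a value of 0.05 means the means differ by 5% of one posterior standard deviation. The `compare_backends` helper, the 0.1 threshold, and the channel names and numbers in the example are all hypothetical, and extracting the summaries from Meridian is left out.

```python
def compare_backends(tf_summary, jax_summary, threshold=0.1):
    """Flag channels whose posterior means differ notably between backends.

    Each summary maps a channel name to a (posterior_mean, posterior_sd)
    tuple. Returns {channel: standardized_difference} for channels where
    |mean_tf - mean_jax| / sd_tf exceeds `threshold`.
    """
    flagged = {}
    for channel, (tf_mean, tf_sd) in tf_summary.items():
        if tf_sd <= 0:
            raise ValueError(f"Posterior SD for {channel!r} must be positive.")
        jax_mean, _ = jax_summary[channel]
        diff = abs(tf_mean - jax_mean) / tf_sd
        if diff > threshold:
            flagged[channel] = diff
    return flagged


# Made-up values for illustration only.
tf_summary = {"tv": (2.10, 0.40), "search": (3.05, 0.50)}
jax_summary = {"tv": (2.12, 0.41), "search": (3.30, 0.52)}
flagged = compare_backends(tf_summary, jax_summary)
print(flagged)  # Only 'search' is flagged.
```

Scaling by the posterior standard deviation, rather than by the mean, keeps the comparison meaningful for channels whose ROI is close to zero.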

Performance considerations

In internal testing on GPUs, JAX reduced the average runtime of initial model runs by about 40% and memory usage by about 70% compared to TensorFlow. For subsequent model iterations, JAX also ran about 2x faster, used about 4x less memory, and removed the need for kernel restarts, keeping workflows uninterrupted.
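Speedups depend on your hardware, data, and model, so it is worth measuring on your own workload. A minimal sketch, assuming you run the same script once per backend in separate processes: time a few consecutive calls and look at the first call separately, because JAX compiles functions with XLA the first time they run, so a first call can include compilation time. `time_runs` is a hypothetical helper, `mmm` in the comment stands for an already-built Meridian model, and the sketch measures wall-clock time only, not GPU memory.

```python
import time


def time_runs(fn, n_runs=3):
    """Call fn() n_runs times and return the wall-clock seconds of each call."""
    timings = []
    for _ in range(n_runs):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return timings


# Stand-in workload so the sketch runs on its own. Replace it with your own
# call, for example: lambda: mmm.sample_posterior(n_chains=4, n_adapt=500,
# n_burnin=500, n_keep=1000)
timings = time_runs(lambda: sum(i * i for i in range(100_000)))
first, rest = timings[0], timings[1:]
print(f"first call: {first:.3f}s, later calls: {[round(t, 3) for t in rest]}")
```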

Because JAX uses memory more efficiently, you have more headroom to increase computationally intensive parameters. For example, in Meridian.sample_posterior(), you might increase the unrolled_leapfrog_steps argument (e.g., from 1 to 5). This increases the maximum trajectory length of the No-U-Turn Sampler (NUTS), which can speed up convergence, and the extra memory headroom lets you do so without exceeding hardware memory limits. You can also increase the n_adapt parameter, the number of adaptation draws, to further aid convergence during the adaptation phase.
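The effect of unrolled_leapfrog_steps can be made concrete with the rule documented for TensorFlow Probability's NUTS kernel, where the number of leapfrog steps per iteration is bounded by unrolled_leapfrog_steps × 2**max_tree_depth. The sketch assumes Meridian passes these arguments through with the same meaning and uses max_tree_depth=10 (TFP's default) for illustration. The sampler settings in the dictionary are example values, not recommendations; check the Meridian API reference for defaults before using them.

```python
def max_leapfrog_steps(max_tree_depth: int = 10, unrolled_leapfrog_steps: int = 1) -> int:
    """Upper bound on leapfrog steps per NUTS iteration (TFP semantics, assumed)."""
    return unrolled_leapfrog_steps * 2**max_tree_depth


print(max_leapfrog_steps(unrolled_leapfrog_steps=1))  # 1024
print(max_leapfrog_steps(unrolled_leapfrog_steps=5))  # 5120

# Example arguments for Meridian.sample_posterior(); values are illustrative.
sampling_kwargs = dict(
    n_chains=4,
    n_adapt=2000,               # Raise if adaptation struggles to converge.
    n_burnin=500,
    n_keep=1000,
    unrolled_leapfrog_steps=5,  # 5x the default maximum trajectory length.
    seed=1,
)
# mmm.sample_posterior(**sampling_kwargs)  # mmm: an existing Meridian model.
```

Longer trajectories cost more computation per draw, so it is worth confirming that the improved convergence outweighs the extra time on your hardware.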