JAX バックエンドを使用する

このガイドでは、メリディアンで JAX バックエンドを使用する方法について説明します。

JAX バックエンドの概要

デフォルトでは、メリディアンはコアとなる数値演算と確率的なマルコフ連鎖モンテカルロ(MCMC)サンプリングに TensorFlow を使用し、すべてのモデリング タスクに堅牢で完全にテストされた基盤を提供します。

パフォーマンスとメモリ効率の向上を必要とするプロジェクトには、メリディアンが JAX バックエンドを提供します。JAX は関数型プログラミング スタイルを推奨し、XLA(Accelerated Linear Algebra)コンパイルを利用して高度なパフォーマンス最適化を実現します。

チュートリアル: JAX の動作を確認するには、JAX スタートガイドのノートブックをご覧ください。

JAX を有効にする方法

コアとなる数学ライブラリは初期化時に読み込まれるため、メリディアン モジュールをインポートする前に、JAX の使用をメリディアンに指示する必要があります。

JAX を有効にするには、環境変数 MERIDIAN_BACKEND'jax' に設定します。

import meridian ステートメントを実行する前に、スクリプトでこの環境変数を設定する必要があります。

import os

# Enable JAX before importing Meridian
os.environ['MERIDIAN_BACKEND'] = 'jax'

# Now it is safe to import Meridian modules
from meridian.model import model
from meridian.data import load

64 ビット精度を有効にする

収束が困難なモデルの場合、JAX バックエンドで 64 ビット精度を使用すると、数値の安定性が向上します。ただし、メモリ使用量が増えて計算時間が長くなるため、ほとんどのユースケースでは、デフォルトの 32 ビット精度のままになっています。64 ビット精度を有効にするには、メリディアンをインポートする前に MERIDIAN_ENABLE_JAX_X64 環境変数を 'True' に設定します。

import os

# Enable JAX 64-bit precision
os.environ['MERIDIAN_ENABLE_JAX_X64'] = 'True'

# Enable JAX backend
os.environ['MERIDIAN_BACKEND'] = 'jax'

# Now it is safe to import Meridian modules
from meridian.model import model

MERIDIAN_BACKEND 環境変数に無効な文字列が指定された場合、メリディアンはランタイム警告(RuntimeWarning)を発出し、デフォルトで標準の TensorFlow 実行に戻ります。MERIDIAN_ENABLE_JAX_X64 環境変数に「True」または「1」以外の値が指定された場合、64 ビット精度は有効にならず、メリディアンはデフォルトの 32 ビット精度を使用します。

数値の違いと再現性

TensorFlow と JAX では計算グラフのコンパイル方法が異なるため、同じデータと乱数シードを使用して JAX に切り替えると、事後推定値にわずかな数値の違いが生じる場合があります。

バックエンド間で事後分布が同一でない場合もありますが、その差は通常小さく、費用対効果や予算配分などのビジネス指標では統計的に有意ではありません。これにより、JAX バックエンドへの切り替え後も、モデルの分析情報の整合性が維持されます。

パフォーマンスに関する注意事項

内部テストでは、JAX によって初期モデルの実行速度が大幅に上昇し、GPU 使用時の TensorFlow と比較して、平均実行時間が約 40%、メモリ使用量が約 70% 削減されることがわかりました。また、JAX はモデルのイテレーションも効率化することで、実行時間を 2 倍高速化してメモリ使用量を 4 分の 1 に削減し、カーネルの再起動を不要にしてワークフローの中断をなくしました。

メモリ効率が向上したことで、計算負荷の高いパラメータを調整する余地が広がりました。たとえば、Meridian.sample_posterior() では、unrolled_leapfrog_steps 引数を引き上げることが可能です(例: 1 を 5 に)。これにより、ハードウェアのメモリ上限を超えずに No-U-Turn-Sampler(NUTS)の軌跡を長くして、収束を加速できます。また、適応フェーズでの収束をさらに促進できるよう、n_adapt パラメータの値を引き上げることも可能です。