Ce guide explique comment utiliser le backend JAX dans Meridian.
Présentation du backend JAX
Par défaut, Meridian utilise TensorFlow pour ses opérations numériques de base et son échantillonnage probabiliste MCMC (Monte-Carlo par chaînes de Markov), ce qui offre une base solide et minutieusement testée pour toutes les tâches de modélisation.
Pour les projets qui bénéficient de performances et d'une efficacité de mémoire améliorées, Meridian fournit le backend JAX. JAX encourage un style de programmation fonctionnel et utilise la compilation XLA (Accelerated Linear Algebra) pour des optimisations de performances avancées.
Tutoriel : pour voir JAX en action, consultez le notebook d'introduction à JAX.
Activer JAX
Étant donné que les bibliothèques mathématiques de base sont chargées lors de l'initialisation, vous devez indiquer à Meridian d'utiliser JAX avant d'importer des modules Meridian.
Pour activer JAX, définissez la variable d'environnement MERIDIAN_BACKEND sur 'jax'.
Vous devez définir cette variable d'environnement dans votre script avant d'exécuter des instructions import meridian :
import os
# Enable JAX before importing Meridian
os.environ['MERIDIAN_BACKEND'] = 'jax'
# Now it is safe to import Meridian modules
from meridian.model import model
from meridian.data import load
Activer la précision 64 bits
Pour les modèles où la convergence est difficile à obtenir, l'utilisation d'une précision 64 bits avec le backend JAX peut améliorer la stabilité numérique. Bien que la précision 64 bits offre une meilleure stabilité numérique, elle entraîne une utilisation accrue de la mémoire et des délais de calcul plus longs. Par conséquent, la précision 32 bits reste la valeur par défaut pour la plupart des cas d'utilisation. Pour l'activer, définissez la variable d'environnement MERIDIAN_ENABLE_JAX_X64 sur 'True' avant d'importer Meridian.
import os
# Enable JAX 64-bit precision
os.environ['MERIDIAN_ENABLE_JAX_X64'] = 'True'
# Enable JAX backend
os.environ['MERIDIAN_BACKEND'] = 'jax'
# Now it is safe to import Meridian modules
from meridian.model import model
Si une chaîne non valide est fournie à la variable d'environnement MERIDIAN_BACKEND, Meridian émet un RuntimeWarning et revient à l'exécution standard de TensorFlow. Si une valeur autre que "True" ou "1" est fournie à la variable d'environnement MERIDIAN_ENABLE_JAX_X64, la précision 64 bits n'est pas activée et Meridian utilise par défaut la précision 32 bits.
Différences numériques et reproductibilité
Étant donné que TensorFlow et JAX compilent leurs graphes de calcul différemment, vous pouvez observer de légères différences numériques dans vos estimations a posteriori lorsque vous passez à JAX en utilisant les mêmes données et les mêmes graines aléatoires.
Bien que les distributions a posteriori puissent ne pas être identiques entre les backends, les différences sont généralement faibles et statistiquement non pertinentes pour les métriques commerciales telles que le ROI et la répartition du budget. Cela permet de s'assurer que le passage au backend JAX préserve l'intégrité des insights de votre modèle.
Considérations sur les performances
Les tests internes ont montré que JAX optimisait les exécutions initiales de modèles, réduisant le temps d'exécution moyen d'environ 40 % et l'utilisation de la mémoire d'environ 70 % par rapport à TensorFlow lors de l'utilisation de GPU. JAX a également simplifié les itérations de modèles, ce qui a permis de réduire de moitié les temps d'exécution, de diviser l'utilisation de la mémoire par quatre et de ne pas interrompre les workflows en éliminant le besoin de redémarrer le kernel.
Grâce à l'amélioration de l'efficacité de la mémoire, vous disposez d'une plus grande marge de manœuvre pour ajuster les paramètres gourmands en ressources de calcul. Ainsi, dans Meridian.sample_posterior(), vous pouvez augmenter l'argument unrolled_leapfrog_steps (par exemple, de 1 à 5). Cela peut accélérer la convergence en augmentant la longueur de la trajectoire de l'échantillonneur NUTS (No-U-Turn-Sampler) sans dépasser les limites de mémoire matérielle. Vous pouvez également augmenter le paramètre n_adapt pour faciliter davantage la convergence pendant la phase d'adaptation.