Cómo usar el backend de JAX

En esta guía, se explica cómo usar el backend de JAX en Meridian.

Introducción al backend de JAX

De forma predeterminada, Meridian usa TensorFlow para sus operaciones numéricas principales y el muestreo probabilístico de Monte Carlo basado en cadenas de Markov (MCMC), lo que proporciona una base sólida y exhaustivamente probada para todas las tareas de modelado.

Para los proyectos que se benefician de un rendimiento mejorado y una mayor eficiencia de la memoria, Meridian proporciona el backend de JAX. JAX fomenta un estilo de programación funcional y utiliza la compilación de XLA (álgebra lineal acelerada) para ofrecer optimizaciones avanzadas del rendimiento.

Instructivo: Para ver JAX en acción, consulta el notebook Primeros pasos con JAX.

Cómo habilitar JAX

Debido a que las bibliotecas matemáticas principales se cargan durante la inicialización, debes indicarle a Meridian que use JAX antes de importar cualquier módulo de Meridian.

Para habilitar JAX, configura la variable de entorno MERIDIAN_BACKEND como 'jax'.

Debes establecer esta variable de entorno en tu secuencia de comandos antes de ejecutar cualquier sentencia import meridian:

import os

# Enable JAX before importing Meridian
os.environ['MERIDIAN_BACKEND'] = 'jax'

# Now it is safe to import Meridian modules
from meridian.model import model
from meridian.data import load

Cómo habilitar la precisión de 64 bits

En el caso de los modelos en los que es difícil lograr la convergencia, usar una precisión de 64 bits con el backend de JAX puede proporcionar una mayor estabilidad numérica. Si bien la precisión de 64 bits ofrece una mejor estabilidad numérica, conlleva un mayor uso de memoria y tiempos de procesamiento más lentos. Por lo tanto, la precisión de 32 bits sigue siendo la predeterminada para la mayoría de los casos de uso. Para habilitarla, configura la variable de entorno MERIDIAN_ENABLE_JAX_X64 en 'True' antes de importar Meridian.

import os

# Enable JAX 64-bit precision
os.environ['MERIDIAN_ENABLE_JAX_X64'] = 'True'

# Enable JAX backend
os.environ['MERIDIAN_BACKEND'] = 'jax'

# Now it is safe to import Meridian modules
from meridian.model import model

Si se proporciona una cadena no válida a la variable de entorno MERIDIAN_BACKEND, Meridian emitirá un RuntimeWarning y volverá a la ejecución estándar de TensorFlow. Si se proporciona un valor distinto de "True" o "1" para la variable de entorno MERIDIAN_ENABLE_JAX_X64, no se habilita la precisión de 64 bits y Meridian usa la precisión de 32 bits de forma predeterminada.

Diferencias numéricas y reproducibilidad

Dado que TensorFlow y JAX compilan sus grafos de procesamiento de manera diferente, es posible que observes pequeñas diferencias numéricas en tus estimaciones a posteriori cuando cambies a JAX con los mismos datos y las mismas semillas aleatorias.

Si bien las distribuciones a posteriori podrían no ser idénticas en todos los backends, las diferencias suelen ser pequeñas y no tienen importancia estadística para las métricas comerciales, como el ROI y la asignación del presupuesto. Esto garantiza que el cambio al backend de JAX mantenga la integridad de las estadísticas de tu modelo.

Consideraciones de rendimiento

Las pruebas internas revelaron que JAX potenció las ejecuciones iniciales del modelo, lo que redujo el tiempo de ejecución promedio en un 40% y el uso de memoria en un 70%, en comparación con TensorFlow cuando se usan GPU. JAX también optimizó las iteraciones del modelo, lo que permitió tiempos de ejecución 2 veces más rápidos, un uso de memoria 4 veces menor y flujos de trabajo ininterrumpidos, ya que se eliminó la necesidad de reiniciar el kernel.

Gracias a la mayor eficiencia de la memoria, tienes más margen para ajustar los parámetros que requieren una gran cantidad de procesamiento. Por ejemplo, en Meridian.sample_posterior(), puedes aumentar el argumento unrolled_leapfrog_steps (p. ej., de 1 a 5). Esto puede acelerar la convergencia, ya que aumenta la longitud de la trayectoria del No-U-Turn-Sampler (NUTS) sin exceder los límites de memoria del hardware. También puedes aumentar el parámetro n_adapt para ayudar aún más a la convergencia durante la fase de adaptación.