Este guia explica como usar o back-end do JAX no Meridian.
Introdução ao back-end do JAX
Por padrão, o Meridian usa o TensorFlow para as principais operações numéricas e a amostragem de Monte Carlo via cadeias de Markov (MCMC, na sigla em inglês) probabilística, oferecendo uma base robusta e totalmente testada para todas as tarefas de modelagem.
Para projetos que se beneficiam de desempenho e eficiência de memória aprimorados, o Meridian oferece o back-end do JAX. O JAX incentiva um estilo de programação funcional e usa a compilação de álgebra linear acelerada (XLA, na sigla em inglês) para oferecer otimizações avançadas de desempenho.
Tutorial: para ver o JAX em ação, consulte o notebook Introdução ao JAX.
Como ativar o JAX
As bibliotecas matemáticas principais são carregadas na inicialização. Por isso, você precisa instruir o Meridian a usar o JAX antes de importar qualquer módulo do Meridian.
Para ativar o JAX, defina a variável de ambiente MERIDIAN_BACKEND como 'jax'.
Defina essa variável de ambiente no script antes de executar qualquer instrução import meridian:
import os
# Enable JAX before importing Meridian
os.environ['MERIDIAN_BACKEND'] = 'jax'
# Now it is safe to import Meridian modules
from meridian.model import model
from meridian.data import load
Ativar a precisão de 64 bits
Para modelos em que a convergência é difícil de alcançar, usar precisão de 64 bits com o back-end do JAX pode melhorar a estabilidade numérica. Apesar disso, essa precisão aumenta o uso de memória e diminui os tempos de computação. Portanto, a precisão de 32 bits continua sendo o padrão para a maioria dos casos de uso. Para ativá-la, defina a variável de ambiente MERIDIAN_ENABLE_JAX_X64 como 'True' antes de importar o Meridian.
import os
# Enable JAX 64-bit precision
os.environ['MERIDIAN_ENABLE_JAX_X64'] = 'True'
# Enable JAX backend
os.environ['MERIDIAN_BACKEND'] = 'jax'
# Now it is safe to import Meridian modules
from meridian.model import model
Se uma string inválida for fornecida à variável de ambiente MERIDIAN_BACKEND, o Meridian vai emitir um aviso de tempo de execução e voltar à execução padrão do TensorFlow. Se um valor diferente de "True" ou "1" for fornecido à variável de ambiente MERIDIAN_ENABLE_JAX_X64, a precisão de 64 bits não será ativada, e o Meridian usará a precisão de 32 bits por padrão.
Diferenças numéricas e reprodutibilidade
Como o TensorFlow e o JAX compilam os grafos computacionais de maneira diferente, você pode observar pequenas diferenças numéricas nas estimativas a posteriori ao mudar para o JAX usando os mesmos dados e sementes aleatórias.
Embora as distribuições a posteriori não sejam idênticas em todos os back-ends, as diferenças geralmente são pequenas e não estatisticamente significativas para métricas de negócios como ROI e alocação de orçamento. Isso garante que a troca para o back-end do JAX mantenha a integridade dos insights do modelo.
Considerações sobre desempenho
Testes internos descobriram que o JAX impulsionou as execuções iniciais do modelo, reduzindo o tempo de execução médio em cerca de 40% e o uso da memória em cerca de 70%, em comparação com o TensorFlow ao usar GPUs. O JAX também simplificou as iterações do modelo, permitindo tempos de execução 2 vezes mais rápidos, uso de memória 4 vezes menor e fluxos de trabalho ininterruptos ao eliminar a necessidade de reinicializações do kernel.
Devido ao aumento da eficiência da memória, você tem mais espaço para ajustar parâmetros computacionalmente intensivos. Por exemplo, em Meridian.sample_posterior(), você pode aumentar o argumento unrolled_leapfrog_steps (por exemplo, de 1 para 5). Isso pode acelerar a convergência aumentando o comprimento da trajetória do No U Turn Sampler (NUTS) sem exceder os limites de memória do hardware. Também é possível aumentar o parâmetro n_adapt para ajudar ainda mais na convergência durante a fase de adaptação.