Converting NumPyro objects to DataTree#

DataTree is the data format ArviZ relies on.

This page covers multiple ways to generate a DataTree from NumPyro MCMC and SVI objects.

See also

We will start by importing the required packages and defining the model. The famous 8 school model.

import arviz_base as az
import numpy as np
from numpy.typing import ArrayLike

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO, autoguide, Predictive
from jax import random
import jax.numpy as jnp
J = 8
y_obs = np.array([28.0, 8.0, -3.0, 7.0, -1.0, 1.0, 18.0, 12.0])
sigma = np.array([15.0, 10.0, 16.0, 11.0, 9.0, 11.0, 10.0, 18.0])
def eight_schools_model(J, sigma, y=None):
    mu = numpyro.sample("mu", dist.Normal(0, 5))
    tau = numpyro.sample("tau", dist.HalfCauchy(5))
    with numpyro.plate("J", J):
        eta = numpyro.sample("eta", dist.Normal(0, 1))
        theta = numpyro.deterministic("theta", mu + tau * eta)
        return numpyro.sample("obs", dist.Normal(theta, sigma), obs=y)
    

def eight_schools_custom_guide(J, sigma, y=None):

    # Variational parameters for mu
    mu_loc = numpyro.param("mu_loc", 0.0)
    mu_scale = numpyro.param("mu_scale", 1.0, constraint=dist.constraints.positive)
    mu = numpyro.sample("mu", dist.Normal(mu_loc, mu_scale))

    # Variational parameters for tau (positive support)
    tau_loc = numpyro.param("tau_loc", 1.0)
    tau_scale = numpyro.param("tau_scale", 0.5, constraint=dist.constraints.positive)
    tau = numpyro.sample("tau", dist.LogNormal(jnp.log(tau_loc), tau_scale))

    # Variational parameters for eta
    eta_loc = numpyro.param("eta_loc", jnp.zeros(J))
    eta_scale = numpyro.param("eta_scale", jnp.ones(J), constraint=dist.constraints.positive)
    with numpyro.plate("J", J):
        eta = numpyro.sample("eta", dist.Normal(eta_loc, eta_scale))

        # Deterministic transform
        numpyro.deterministic("theta", mu + tau * eta)

Convert from MCMC#

This first example shows conversion from MCMC

# fit with MCMC
nuts = NUTS(eight_schools_model)
mcmc = MCMC(nuts, num_warmup = 1000, num_samples = 1000, num_chains=4)
mcmc.run(random.PRNGKey(0), J=J, sigma=sigma, y=y_obs, extra_fields=("num_steps", "energy"),)

# sample the posterior predictive
predictive = Predictive(eight_schools_model, mcmc.get_samples())
samples_predictive = predictive(random.PRNGKey(1), J=J, sigma=sigma)

# Convert to MCMC
idata_mcmc = az.from_numpyro(mcmc, posterior_predictive=samples_predictive)
idata_mcmc
/var/folders/3n/bm6t53l15kddzf7prg_kj3140000gn/T/ipykernel_61101/3262796440.py:3: UserWarning: There are not enough devices to run parallel chains: expected 4 but got 1. Chains will be drawn sequentially. If you are running MCMC in CPU, consider using `numpyro.set_host_device_count(4)` at the beginning of your program. You can double-check how many devices are available in your system using `jax.local_device_count()`.
  mcmc = MCMC(nuts, num_warmup = 1000, num_samples = 1000, num_chains=4)
sample: 100%|██████████| 2000/2000 [00:00<00:00, 3287.85it/s, 7 steps of size 4.01e-01. acc. prob=0.90]
sample: 100%|██████████| 2000/2000 [00:00<00:00, 9001.07it/s, 7 steps of size 4.16e-01. acc. prob=0.86]
sample: 100%|██████████| 2000/2000 [00:00<00:00, 7635.89it/s, 15 steps of size 3.71e-01. acc. prob=0.92]
sample: 100%|██████████| 2000/2000 [00:00<00:00, 7994.02it/s, 15 steps of size 3.47e-01. acc. prob=0.93]
<xarray.DatasetView> Size: 0B
Dimensions:  ()
Data variables:
    *empty*

Convert from SVI with Autoguide#

eight_schools_guide = autoguide.AutoNormal(eight_schools_model, init_loc_fn=numpyro.infer.init_to_median(num_samples=100))
svi = SVI(
    eight_schools_model, 
    guide=eight_schools_guide,
    optim=numpyro.optim.Adam(0.01),
    loss = Trace_ELBO()
)
svi_result = svi.run(random.PRNGKey(0), num_steps=10000, J=J, sigma=sigma, y=y_obs)

# sample the posterior predictive
predictive_svi = Predictive(eight_schools_model, guide=eight_schools_guide, params=svi_result.params, num_samples=4000)
samples_predictive_svi = predictive_svi(random.PRNGKey(1), J=J, sigma=sigma)


idata_svi = az.from_numpyro_svi(
    svi,
    svi_result=svi_result,
    model_kwargs=dict(J=J, sigma=sigma, y=y_obs), # SVI requires providing the fit args/kwargs
    num_samples = 4000, # number of samples to draw in the posterior
    posterior_predictive=samples_predictive_svi
)
idata_svi
100%|██████████| 10000/10000 [00:00<00:00, 11110.33it/s, init loss: 53.6608, avg. loss [9501-10000]: 31.6204]
<xarray.DatasetView> Size: 0B
Dimensions:  ()
Data variables:
    *empty*

Converting from SVI with a custom guide function#

svi_custom_guide = SVI(
    eight_schools_model, 
    guide=eight_schools_custom_guide,
    optim=numpyro.optim.Adam(0.01),
    loss = Trace_ELBO()
)
svi_custom_guide_result = svi_custom_guide.run(random.PRNGKey(0), num_steps=10000, J=J, sigma=sigma, y=y_obs)

# sample the posterior predictive
predictive_svi_custom = Predictive(eight_schools_model, guide=eight_schools_custom_guide, params=svi_result.params, num_samples=4000)
samples_predictive_svi_custom = predictive_svi_custom(random.PRNGKey(1), J=J, sigma=sigma)

idata_svi_custom_guide = az.from_numpyro_svi(
    svi_custom_guide,
    svi_result=svi_custom_guide_result,
    model_kwargs=dict(J=J, sigma=sigma, y=y_obs), # SVI requires providing the fit args/kwargs
    num_samples = 4000, # number of samples to draw in the posterior
    posterior_predictive=samples_predictive_svi_custom
)
idata_svi_custom_guide
100%|██████████| 10000/10000 [00:00<00:00, 10763.36it/s, init loss: 34.9525, avg. loss [9501-10000]: 31.6279]
<xarray.DatasetView> Size: 0B
Dimensions:  ()
Data variables:
    *empty*

Automatically Labelling Event Dims#

NumPyro batch dims are automatically labelled according to their corresponding plate names. In order to label event dims, we add infer={"event_dims": dim_labels} to the numpyro.sample statement as shown below:

def eight_schools_model_zsn(J, sigma, y=None):
    mu = numpyro.sample("mu", dist.Normal(0, 5))
    tau = numpyro.sample("tau", dist.HalfCauchy(5))
    eta = numpyro.sample(
        "eta", 
        dist.ZeroSumNormal(tau, event_shape=(J,)),
        # note: this allows arviz to infer the event dimension labels
        infer={"event_dims":["J"]}
    )
    with numpyro.plate("J", J):
        theta = numpyro.deterministic("theta", mu + eta)
        return numpyro.sample("obs", dist.Normal(theta, sigma), obs=y)


# fit with MCMC
nuts = NUTS(eight_schools_model_zsn)
mcmc2 = MCMC(nuts, num_warmup = 1000, num_samples = 1000, num_chains=4)
mcmc2.run(random.PRNGKey(0), J=J, sigma=sigma, y=y_obs, extra_fields=("num_steps", "energy"),)


# sample the posterior predictive
predictive2 = Predictive(eight_schools_model, mcmc2.get_samples())
samples_predictive2 = predictive2(random.PRNGKey(1), J=J, sigma=sigma)

# Convert to MCMC
idata_mcmc2 = az.from_numpyro(mcmc2, posterior_predictive=samples_predictive2)
/var/folders/3n/bm6t53l15kddzf7prg_kj3140000gn/T/ipykernel_61101/306760900.py:17: UserWarning: There are not enough devices to run parallel chains: expected 4 but got 1. Chains will be drawn sequentially. If you are running MCMC in CPU, consider using `numpyro.set_host_device_count(4)` at the beginning of your program. You can double-check how many devices are available in your system using `jax.local_device_count()`.
  mcmc2 = MCMC(nuts, num_warmup = 1000, num_samples = 1000, num_chains=4)
sample: 100%|██████████| 2000/2000 [00:00<00:00, 2525.59it/s, 3 steps of size 2.56e-01. acc. prob=0.91]
sample: 100%|██████████| 2000/2000 [00:00<00:00, 6867.14it/s, 15 steps of size 1.99e-01. acc. prob=0.90]
sample: 100%|██████████| 2000/2000 [00:00<00:00, 6385.59it/s, 15 steps of size 2.85e-01. acc. prob=0.83]
sample: 100%|██████████| 2000/2000 [00:00<00:00, 7863.09it/s, 3 steps of size 2.61e-01. acc. prob=0.83] 

Notice that eta is labelled appropriately with J

idata_mcmc2
<xarray.DatasetView> Size: 0B
Dimensions:  ()
Data variables:
    *empty*

Extending NumPyro Conversion to other Inference Objects#

NumPyroInferenceAdapter can be leveraged to extend ArviZ conversion to other NumPyro Inference Objects (such as the NestedSampler)

The example below uses the SVI implementation as an example, where an adapter class is created that inherits the NumPyroInferenceAdapter base class

class SVIAutoGuideAdapter(az.NumPyroInferenceAdapter):
    """Adapter for SVI to standardize attributes and methods with other inference objects."""

    def __init__(
        self, svi, *, svi_result,  model_args=None, model_kwargs=None, num_samples = 1000,
    ):
        # Necessary: Specify the inference object, internal model fn, inputs, and sample shape
        super().__init__(
            svi,
            model=getattr(svi.guide, "model", svi.model),
            model_args=model_args,
            model_kwargs=model_kwargs,
            sample_shape=(num_samples,),
        )
        self.result_obj = svi_result # saving this to help with posterior sampling

    # Necessary: Specify the sample dim names and shape. ie MCMC is ("chain", "draw")
    @property
    def sample_dims(self):
        return ["sample"]

    # Necessary: Specify how to get posterior samples from the inference objects
    # for SVI in numpyro, we need to sample from the guide with our SVI params
    def get_samples(self, seed = None, **kwargs):
        key = self.prng_key_func(seed or 0)
        return self.posterior.guide.sample_posterior(
            key,
            self.result_obj.params, # the internal SVI params needed to make predictions
            *self._args,
            sample_shape=self.sample_shape,
            **self._kwargs,
        )

The instantiated adapter can now be passed directly into az.from_numpyro.

adapter = SVIAutoGuideAdapter(
    svi, 
    svi_result=svi_result, 
    model_kwargs=dict(J=J, sigma=sigma, y=y_obs),
    num_samples = 4000
)

idata_svi2 = az.from_numpyro(adapter, posterior_predictive=samples_predictive_svi)
idata_svi2
<xarray.DatasetView> Size: 0B
Dimensions:  ()
Data variables:
    *empty*
%load_ext watermark
%watermark -n -u -v -iv -w
Last updated: Sun Mar 15 2026

Python implementation: CPython
Python version       : 3.12.10
IPython version      : 9.4.0

jax       : 0.9.0.1
numpy     : 2.3.2
numpyro   : 0.20.0
arviz_base: 0.7.0.dev0

Watermark: 2.5.0