arviz_base.SVIAdapter

class arviz_base.SVIAdapter(svi, *, svi_result, model_args=None, model_kwargs=None, num_samples=1000)

Adapter that wraps NumPyro SVI results so they expose the same attributes and methods as other inference objects.

__init__(svi, *, svi_result, model_args=None, model_kwargs=None, num_samples=1000)

Initialize SVI adapter for variational inference results.

Parameters:
svi : numpyro.infer.SVI

Fitted SVI object.

svi_result : numpyro.infer.svi.SVIRunResult

SVI optimization results containing learned parameters.

model_args : tuple, optional

Positional arguments for the model.

model_kwargs : dict, optional

Keyword arguments for the model.

num_samples : int, default 1000

Number of posterior samples to generate from the guide.

Methods

__init__(svi, *, svi_result[, model_args, ...])

Initialize SVI adapter for variational inference results.

get_sample_stats(**kwargs)

Get sample stats from the inference object (e.g., divergences for MCMC).

get_samples([seed])

Get posterior samples from the inference object.

Attributes

sample_dims

Return the sample dimension names.
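
Example

The members listed above form a small shared interface, so downstream code can treat an SVI result like any other inference object. The sketch below illustrates that idea using only the standard library. It does not import arviz_base or NumPyro: `InferenceLike`, `StandInSVIAdapter`, and `describe` are hypothetical names invented for this illustration. The return shapes (a dict of draws keyed by variable name, an empty sample-stats dict, a single "sample" dimension) are assumptions, not documented behavior of `SVIAdapter`.

```python
import random
from typing import Protocol


class InferenceLike(Protocol):
    """Hypothetical protocol mirroring the members documented above."""

    @property
    def sample_dims(self) -> list[str]: ...

    def get_samples(self, seed=None) -> dict: ...

    def get_sample_stats(self, **kwargs) -> dict: ...


class StandInSVIAdapter:
    """Stand-in with the documented member names; NOT arviz_base.SVIAdapter."""

    def __init__(self, num_samples=1000):
        self.num_samples = num_samples

    @property
    def sample_dims(self):
        # Assumption: a single flat sample dimension.
        return ["sample"]

    def get_samples(self, seed=None):
        # Assumption: samples come back as {variable name: draws}.
        rng = random.Random(seed)
        return {"mu": [rng.gauss(0.0, 1.0) for _ in range(self.num_samples)]}

    def get_sample_stats(self, **kwargs):
        # Assumption: nothing MCMC-specific (such as divergences) to report.
        return {}


def describe(inference: InferenceLike, seed=0):
    """Summarize any object that exposes the three documented members."""
    samples = inference.get_samples(seed)
    return {
        "dims": list(inference.sample_dims),
        "variables": sorted(samples),
        "draws": {name: len(values) for name, values in samples.items()},
        "stats": sorted(inference.get_sample_stats()),
    }


summary = describe(StandInSVIAdapter(num_samples=50), seed=1)
print(summary)
```

Because `describe` depends only on the three member names, the same function would accept any other object exposing them, which is the kind of standardization the adapter is described as providing.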