diff --git a/book/algorithms/mclmc.md b/book/algorithms/mclmc.md index 77889ea..051441f 100644 --- a/book/algorithms/mclmc.md +++ b/book/algorithms/mclmc.md @@ -44,7 +44,7 @@ MCLMC in Blackjax comes with a tuning algorithm which attempts to find optimal v An example is given below, of tuning and running a chain for a 1000 dimensional Gaussian target (of which a 2 dimensional marginal is plotted): -```{code-cell} +```{code-cell} ipython3 :tags: [hide-cell] import matplotlib.pyplot as plt @@ -66,7 +66,7 @@ from numpyro.infer.util import initialize_model rng_key = jax.random.key(int(date.today().strftime("%Y%m%d"))) ``` -```{code-cell} +```{code-cell} ipython3 def run_mclmc(logdensity_fn, num_steps, initial_position, key, transform, desired_energy_variance= 5e-4): init_key, tune_key, run_key = jax.random.split(key, 3) @@ -76,16 +76,17 @@ def run_mclmc(logdensity_fn, num_steps, initial_position, key, transform, desire ) # build the kernel - kernel = lambda sqrt_diag_cov : blackjax.mcmc.mclmc.build_kernel( + kernel = lambda inverse_mass_matrix : blackjax.mcmc.mclmc.build_kernel( logdensity_fn=logdensity_fn, integrator=blackjax.mcmc.integrators.isokinetic_mclachlan, - sqrt_diag_cov=sqrt_diag_cov, + inverse_mass_matrix=inverse_mass_matrix, ) # find values for L and step_size ( blackjax_state_after_tuning, blackjax_mclmc_sampler_params, + _ ) = blackjax.mclmc_find_L_and_step_size( mclmc_kernel=kernel, num_steps=num_steps, @@ -115,7 +116,7 @@ def run_mclmc(logdensity_fn, num_steps, initial_position, key, transform, desire return samples, blackjax_state_after_tuning, blackjax_mclmc_sampler_params, run_key ``` -```{code-cell} +```{code-cell} ipython3 # run the algorithm on a high dimensional gaussian, and show two of the dimensions logdensity_fn = lambda x: -0.5 * jnp.sum(jnp.square(x)) @@ -134,13 +135,13 @@ samples, initial_state, params, chain_key = run_mclmc( samples.mean() ``` -```{code-cell} +```{code-cell} ipython3 plt.scatter(x=samples[:, 0], y=samples[:, 1], alpha=0.1) plt.axis("equal") plt.title("Scatter Plot of Samples") ``` -```{code-cell} +```{code-cell} ipython3 def visualize_results_gauss(samples, label, color): x1 = samples[:, 0] plt.hist(x1, bins= 30, density= True, histtype= 'step', lw= 4, color= color, label= label) @@ -165,12 +166,12 @@ ground_truth_gauss() A natural sanity check is to see if reducing $\epsilon$ changes the inferred distribution to an extent you care about. For example, we can inspect the 1D marginal with a stepsize $\epsilon$ as above, and compare it to a stepsize $\epsilon/2$ (and double the number of steps). We show this comparison below: -```{code-cell} +```{code-cell} ipython3 new_params = params._replace(step_size= params.step_size / 2) new_num_steps = num_steps * 2 ``` -```{code-cell} +```{code-cell} ipython3 sampling_alg = blackjax.mclmc( logdensity_fn, L=new_params.L, @@ -211,7 +212,7 @@ Our task is to find the posterior of the parameters $\{R_n\}_{n =1}^N$, $\sigma$ First, we get the data, define a model using NumPyro, and draw samples: -```{code-cell} +```{code-cell} ipython3 import matplotlib.dates as mdates from numpyro.examples.datasets import SP500, load_dataset from numpyro.distributions import StudentT @@ -243,7 +244,7 @@ def setup(): setup() ``` -```{code-cell} +```{code-cell} ipython3 def from_numpyro(model, rng_key, model_args): init_params, potential_fn_gen, *_ = initialize_model( rng_key, @@ -272,13 +273,13 @@ rng_key = jax.random.key(42) logp_sv, x_init = from_numpyro(stochastic_volatility, rng_key, model_args) ``` -```{code-cell} +```{code-cell} ipython3 num_steps = 20000 samples, initial_state, params, chain_key = run_mclmc(logdensity_fn= logp_sv, num_steps= num_steps, initial_position= x_init, key= sample_key, transform=lambda state, info: state.position) ``` -```{code-cell} +```{code-cell} ipython3 def visualize_results_sv(samples, color, label): R = np.exp(np.array(samples['s'])) # take an exponent to get R @@ -297,7 +298,7 @@ plt.legend() plt.show() ``` -```{code-cell} +```{code-cell} ipython3 new_params = params._replace(step_size = params.step_size/2) new_num_steps = num_steps * 2 @@ -320,7 +321,7 @@ _, new_samples = blackjax.util.run_inference_algorithm( ) ``` -```{code-cell} +```{code-cell} ipython3 setup() visualize_results_sv(new_samples,'red', 'MCLMC', ) visualize_results_sv(samples,'teal', 'MCLMC (stepsize/2)', ) @@ -331,7 +332,7 @@ plt.show() Here, we have again inspected the effect of halving $\epsilon$. This looks OK, but suppose we are interested in the hierarchial parameters in particular, which tend to be harder to infer. We now inspect the marginal of a hierarchical parameter: -```{code-cell} +```{code-cell} ipython3 def visualize_results_sv_marginal(samples, color, label): # plt.subplot(1, 2, 1) # plt.hist(samples['nu'], bins = 20, histtype= 'step', lw= 4, density= True, color= color, label= label) @@ -355,26 +356,31 @@ If we care about this parameter in particular, we should reduce step size furthe ## Adjusted MCLMC -Blackjax also provides an adjusted version of the algorithm. This also has two hyperparameters, `step_size` and `L`. `L` is related to the `L` parameter of the unadjusted version, but not identical. The tuning algorithm is also similar, but uses a dual averaging scheme to tune the step size. We find in practice that a target MH acceptance rate of 0.9 is a good choice. +Blackjax also provides an adjusted version of the algorithm. This also has two hyperparameters, `step_size` and `L`. `L` is related to the `L` parameter of the unadjusted version, but not identical (It determines the length of a proposal, and since momentum is resampled after a proposal, length of proposal determines the momentum decoherence rate). It is also possible to have Langevin noise during the trajectory, although we don't see improvements here. -```{code-cell} -from blackjax.mcmc.adjusted_mclmc import rescale +The tuning algorithm is also similar, but uses a dual averaging scheme to tune the step size. We find in practice that a target MH acceptance rate of 0.9 is a good choice. + +**Our recommendation is to use the unadjusted version when possible**, but if you really believe you need the algorithm to be asymptotically unbiased (it's not obvious why you would), you should use the adjusted version, with mass matrix preconditioning, randomized trajectory length and no in-proposal Langevin noise. We encapsulate these best practices in `run_adjusted_mcmc` below: + + +```{code-cell} ipython3 +from blackjax.mcmc.adjusted_mclmc_dynamic import rescale from blackjax.util import run_inference_algorithm -def run_adjusted_mclmc( +def run_adjusted_mclmc_dynamic( logdensity_fn, num_steps, initial_position, key, transform=lambda state, _ : state.position, - diagonal_preconditioning=False, + diagonal_preconditioning=True, random_trajectory_length=True, L_proposal_factor=jnp.inf ): init_key, tune_key, run_key = jax.random.split(key, 3) - initial_state = blackjax.mcmc.adjusted_mclmc.init( + initial_state = blackjax.mcmc.adjusted_mclmc_dynamic.init( position=initial_position, logdensity_fn=logdensity_fn, random_generator_arg=init_key, @@ -386,9 +392,9 @@ def run_adjusted_mclmc( else: integration_steps_fn = lambda avg_num_integration_steps: lambda _: jnp.ceil(avg_num_integration_steps) - kernel = lambda rng_key, state, avg_num_integration_steps, step_size, sqrt_diag_cov: blackjax.mcmc.adjusted_mclmc.build_kernel( + kernel = lambda rng_key, state, avg_num_integration_steps, step_size, inverse_mass_matrix: blackjax.mcmc.adjusted_mclmc_dynamic.build_kernel( integration_steps_fn=integration_steps_fn(avg_num_integration_steps), - sqrt_diag_cov=sqrt_diag_cov, + inverse_mass_matrix=inverse_mass_matrix, )( rng_key=rng_key, state=state, @@ -402,6 +408,7 @@ def run_adjusted_mclmc( ( blackjax_state_after_tuning, blackjax_mclmc_sampler_params, + _ ) = blackjax.adjusted_mclmc_find_L_and_step_size( mclmc_kernel=kernel, num_steps=num_steps, @@ -410,20 +417,20 @@ def run_adjusted_mclmc( target=target_acc_rate, frac_tune1=0.1, frac_tune2=0.1, - frac_tune3=0.0, # our recommendation + frac_tune3=0.1, # our recommendation diagonal_preconditioning=diagonal_preconditioning, ) step_size = blackjax_mclmc_sampler_params.step_size L = blackjax_mclmc_sampler_params.L - alg = blackjax.adjusted_mclmc( + alg = blackjax.adjusted_mclmc_dynamic( logdensity_fn=logdensity_fn, step_size=step_size, integration_steps_fn=lambda key: jnp.ceil( jax.random.uniform(key) * rescale(L / step_size) ), - sqrt_diag_cov=blackjax_mclmc_sampler_params.sqrt_diag_cov, + inverse_mass_matrix=blackjax_mclmc_sampler_params.inverse_mass_matrix, L_proposal_factor=L_proposal_factor, ) @@ -439,45 +446,31 @@ def run_adjusted_mclmc( return out ``` -```{code-cell} -# run the algorithm on a high dimensional gaussian, and show two of the dimensions - -sample_key, rng_key = jax.random.split(rng_key) -samples = run_adjusted_mclmc( - logdensity_fn=lambda x: -0.5 * jnp.sum(jnp.square(x)), - num_steps=1000, - initial_position=jnp.ones((1000,)), - key=sample_key, -) -plt.scatter(x=samples[:, 0], y=samples[:, 1], alpha=0.1) -plt.axis("equal") -plt.title("Scatter Plot of Samples") -``` - -```{code-cell} +```{code-cell} ipython3 # run the algorithm on a high dimensional gaussian, and show two of the dimensions sample_key, rng_key = jax.random.split(rng_key) -samples = run_adjusted_mclmc( +samples = run_adjusted_mclmc_dynamic( logdensity_fn=lambda x: -0.5 * jnp.sum(jnp.square(x)), num_steps=1000, initial_position=jnp.ones((1000,)), key=sample_key, - random_trajectory_length=False, - L_proposal_factor=1.25, ) plt.scatter(x=samples[:, 0], y=samples[:, 1], alpha=0.1) plt.axis("equal") plt.title("Scatter Plot of Samples") ``` -```{bibliography} -:filter: docname in docnames -``` - +```{code-cell} ipython3 +num_steps = 10000 +adjusted_samples = run_adjusted_mclmc_dynamic(logdensity_fn= logp_sv, num_steps= num_steps, initial_position= x_init, key= sample_key) ``` -```{code-cell} +```{code-cell} ipython3 +setup() +visualize_results_sv(adjusted_samples, color= 'navy', label= 'volatility posterior') +plt.legend() +plt.show() ```