Skip to content

Commit

Permalink
Adjusted mclmc updates (#67)
Browse files Browse the repository at this point in the history
* WIP

* Updated run_inference_algorithm

* UPDATE EXAMPLE

* UPDATE EXAMPLE

* UPDATE EXAMPLE

* UPDATE EXAMPLE

* mams

* mams

* adjusted

* adjusted

* update

* update adjusted mclmc
  • Loading branch information
reubenharry authored Feb 9, 2025
1 parent 960dea0 commit dc83595
Showing 1 changed file with 44 additions and 51 deletions.
95 changes: 44 additions & 51 deletions book/algorithms/mclmc.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ MCLMC in Blackjax comes with a tuning algorithm which attempts to find optimal v

An example is given below, of tuning and running a chain for a 1000 dimensional Gaussian target (of which a 2 dimensional marginal is plotted):

```{code-cell}
```{code-cell} ipython3
:tags: [hide-cell]
import matplotlib.pyplot as plt
Expand All @@ -66,7 +66,7 @@ from numpyro.infer.util import initialize_model
rng_key = jax.random.key(int(date.today().strftime("%Y%m%d")))
```

```{code-cell}
```{code-cell} ipython3
def run_mclmc(logdensity_fn, num_steps, initial_position, key, transform, desired_energy_variance= 5e-4):
init_key, tune_key, run_key = jax.random.split(key, 3)
Expand All @@ -76,16 +76,17 @@ def run_mclmc(logdensity_fn, num_steps, initial_position, key, transform, desire
)
# build the kernel
kernel = lambda sqrt_diag_cov : blackjax.mcmc.mclmc.build_kernel(
kernel = lambda inverse_mass_matrix : blackjax.mcmc.mclmc.build_kernel(
logdensity_fn=logdensity_fn,
integrator=blackjax.mcmc.integrators.isokinetic_mclachlan,
sqrt_diag_cov=sqrt_diag_cov,
inverse_mass_matrix=inverse_mass_matrix,
)
# find values for L and step_size
(
blackjax_state_after_tuning,
blackjax_mclmc_sampler_params,
_
) = blackjax.mclmc_find_L_and_step_size(
mclmc_kernel=kernel,
num_steps=num_steps,
Expand Down Expand Up @@ -115,7 +116,7 @@ def run_mclmc(logdensity_fn, num_steps, initial_position, key, transform, desire
return samples, blackjax_state_after_tuning, blackjax_mclmc_sampler_params, run_key
```

```{code-cell}
```{code-cell} ipython3
# run the algorithm on a high dimensional gaussian, and show two of the dimensions
logdensity_fn = lambda x: -0.5 * jnp.sum(jnp.square(x))
Expand All @@ -134,13 +135,13 @@ samples, initial_state, params, chain_key = run_mclmc(
samples.mean()
```

```{code-cell}
```{code-cell} ipython3
plt.scatter(x=samples[:, 0], y=samples[:, 1], alpha=0.1)
plt.axis("equal")
plt.title("Scatter Plot of Samples")
```

```{code-cell}
```{code-cell} ipython3
def visualize_results_gauss(samples, label, color):
x1 = samples[:, 0]
plt.hist(x1, bins= 30, density= True, histtype= 'step', lw= 4, color= color, label= label)
Expand All @@ -165,12 +166,12 @@ ground_truth_gauss()

A natural sanity check is to see if reducing $\epsilon$ changes the inferred distribution to an extent you care about. For example, we can inspect the 1D marginal with a stepsize $\epsilon$ as above, and compare it to a stepsize $\epsilon/2$ (and double the number of steps). We show this comparison below:

```{code-cell}
```{code-cell} ipython3
new_params = params._replace(step_size= params.step_size / 2)
new_num_steps = num_steps * 2
```

```{code-cell}
```{code-cell} ipython3
sampling_alg = blackjax.mclmc(
logdensity_fn,
L=new_params.L,
Expand Down Expand Up @@ -211,7 +212,7 @@ Our task is to find the posterior of the parameters $\{R_n\}_{n =1}^N$, $\sigma$

First, we get the data, define a model using NumPyro, and draw samples:

```{code-cell}
```{code-cell} ipython3
import matplotlib.dates as mdates
from numpyro.examples.datasets import SP500, load_dataset
from numpyro.distributions import StudentT
Expand Down Expand Up @@ -243,7 +244,7 @@ def setup():
setup()
```

```{code-cell}
```{code-cell} ipython3
def from_numpyro(model, rng_key, model_args):
init_params, potential_fn_gen, *_ = initialize_model(
rng_key,
Expand Down Expand Up @@ -272,13 +273,13 @@ rng_key = jax.random.key(42)
logp_sv, x_init = from_numpyro(stochastic_volatility, rng_key, model_args)
```

```{code-cell}
```{code-cell} ipython3
num_steps = 20000
samples, initial_state, params, chain_key = run_mclmc(logdensity_fn= logp_sv, num_steps= num_steps, initial_position= x_init, key= sample_key, transform=lambda state, info: state.position)
```

```{code-cell}
```{code-cell} ipython3
def visualize_results_sv(samples, color, label):
R = np.exp(np.array(samples['s'])) # take an exponent to get R
Expand All @@ -297,7 +298,7 @@ plt.legend()
plt.show()
```

```{code-cell}
```{code-cell} ipython3
new_params = params._replace(step_size = params.step_size/2)
new_num_steps = num_steps * 2
Expand All @@ -320,7 +321,7 @@ _, new_samples = blackjax.util.run_inference_algorithm(
)
```

```{code-cell}
```{code-cell} ipython3
setup()
visualize_results_sv(new_samples,'red', 'MCLMC', )
visualize_results_sv(samples,'teal', 'MCLMC (stepsize/2)', )
Expand All @@ -331,7 +332,7 @@ plt.show()

Here, we have again inspected the effect of halving $\epsilon$. This looks OK, but suppose we are interested in the hierarchial parameters in particular, which tend to be harder to infer. We now inspect the marginal of a hierarchical parameter:

```{code-cell}
```{code-cell} ipython3
def visualize_results_sv_marginal(samples, color, label):
# plt.subplot(1, 2, 1)
# plt.hist(samples['nu'], bins = 20, histtype= 'step', lw= 4, density= True, color= color, label= label)
Expand All @@ -355,26 +356,31 @@ If we care about this parameter in particular, we should reduce step size furthe

## Adjusted MCLMC

Blackjax also provides an adjusted version of the algorithm. This also has two hyperparameters, `step_size` and `L`. `L` is related to the `L` parameter of the unadjusted version, but not identical. The tuning algorithm is also similar, but uses a dual averaging scheme to tune the step size. We find in practice that a target MH acceptance rate of 0.9 is a good choice.
Blackjax also provides an adjusted version of the algorithm. This also has two hyperparameters, `step_size` and `L`. `L` is related to the `L` parameter of the unadjusted version, but not identical (It determines the length of a proposal, and since momentum is resampled after a proposal, length of proposal determines the momentum decoherence rate). It is also possible to have Langevin noise during the trajectory, although we don't see improvements here.

```{code-cell}
from blackjax.mcmc.adjusted_mclmc import rescale
The tuning algorithm is also similar, but uses a dual averaging scheme to tune the step size. We find in practice that a target MH acceptance rate of 0.9 is a good choice.

**Our recommendation is to use the unadjusted version when possible**, but if you really believe you need the algorithm to be asymptotically unbiased (it's not obvious why you would), you should use the adjusted version, with mass matrix preconditioning, randomized trajectory length and no in-proposal Langevin noise. We encapsulate these best practices in `run_adjusted_mcmc` below:


```{code-cell} ipython3
from blackjax.mcmc.adjusted_mclmc_dynamic import rescale
from blackjax.util import run_inference_algorithm
def run_adjusted_mclmc(
def run_adjusted_mclmc_dynamic(
logdensity_fn,
num_steps,
initial_position,
key,
transform=lambda state, _ : state.position,
diagonal_preconditioning=False,
diagonal_preconditioning=True,
random_trajectory_length=True,
L_proposal_factor=jnp.inf
):
init_key, tune_key, run_key = jax.random.split(key, 3)
initial_state = blackjax.mcmc.adjusted_mclmc.init(
initial_state = blackjax.mcmc.adjusted_mclmc_dynamic.init(
position=initial_position,
logdensity_fn=logdensity_fn,
random_generator_arg=init_key,
Expand All @@ -386,9 +392,9 @@ def run_adjusted_mclmc(
else:
integration_steps_fn = lambda avg_num_integration_steps: lambda _: jnp.ceil(avg_num_integration_steps)
kernel = lambda rng_key, state, avg_num_integration_steps, step_size, sqrt_diag_cov: blackjax.mcmc.adjusted_mclmc.build_kernel(
kernel = lambda rng_key, state, avg_num_integration_steps, step_size, inverse_mass_matrix: blackjax.mcmc.adjusted_mclmc_dynamic.build_kernel(
integration_steps_fn=integration_steps_fn(avg_num_integration_steps),
sqrt_diag_cov=sqrt_diag_cov,
inverse_mass_matrix=inverse_mass_matrix,
)(
rng_key=rng_key,
state=state,
Expand All @@ -402,6 +408,7 @@ def run_adjusted_mclmc(
(
blackjax_state_after_tuning,
blackjax_mclmc_sampler_params,
_
) = blackjax.adjusted_mclmc_find_L_and_step_size(
mclmc_kernel=kernel,
num_steps=num_steps,
Expand All @@ -410,20 +417,20 @@ def run_adjusted_mclmc(
target=target_acc_rate,
frac_tune1=0.1,
frac_tune2=0.1,
frac_tune3=0.0, # our recommendation
frac_tune3=0.1, # our recommendation
diagonal_preconditioning=diagonal_preconditioning,
)
step_size = blackjax_mclmc_sampler_params.step_size
L = blackjax_mclmc_sampler_params.L
alg = blackjax.adjusted_mclmc(
alg = blackjax.adjusted_mclmc_dynamic(
logdensity_fn=logdensity_fn,
step_size=step_size,
integration_steps_fn=lambda key: jnp.ceil(
jax.random.uniform(key) * rescale(L / step_size)
),
sqrt_diag_cov=blackjax_mclmc_sampler_params.sqrt_diag_cov,
inverse_mass_matrix=blackjax_mclmc_sampler_params.inverse_mass_matrix,
L_proposal_factor=L_proposal_factor,
)
Expand All @@ -439,45 +446,31 @@ def run_adjusted_mclmc(
return out
```

```{code-cell}
# run the algorithm on a high dimensional gaussian, and show two of the dimensions
sample_key, rng_key = jax.random.split(rng_key)
samples = run_adjusted_mclmc(
logdensity_fn=lambda x: -0.5 * jnp.sum(jnp.square(x)),
num_steps=1000,
initial_position=jnp.ones((1000,)),
key=sample_key,
)
plt.scatter(x=samples[:, 0], y=samples[:, 1], alpha=0.1)
plt.axis("equal")
plt.title("Scatter Plot of Samples")
```

```{code-cell}
```{code-cell} ipython3
# run the algorithm on a high dimensional gaussian, and show two of the dimensions
sample_key, rng_key = jax.random.split(rng_key)
samples = run_adjusted_mclmc(
samples = run_adjusted_mclmc_dynamic(
logdensity_fn=lambda x: -0.5 * jnp.sum(jnp.square(x)),
num_steps=1000,
initial_position=jnp.ones((1000,)),
key=sample_key,
random_trajectory_length=False,
L_proposal_factor=1.25,
)
plt.scatter(x=samples[:, 0], y=samples[:, 1], alpha=0.1)
plt.axis("equal")
plt.title("Scatter Plot of Samples")
```

```{bibliography}
:filter: docname in docnames
```

```{code-cell} ipython3
num_steps = 10000
adjusted_samples = run_adjusted_mclmc_dynamic(logdensity_fn= logp_sv, num_steps= num_steps, initial_position= x_init, key= sample_key)
```

```{code-cell}
```{code-cell} ipython3
setup()
visualize_results_sv(adjusted_samples, color= 'navy', label= 'volatility posterior')
plt.legend()
plt.show()
```

0 comments on commit dc83595

Please sign in to comment.