
Commit 7c7a6f6

Extra detail
1 parent 0e790f3 commit 7c7a6f6

File tree

1 file changed (+86 -21 lines)


src/transforms.qmd

@@ -421,7 +421,7 @@ For example, `logpdf_with_trans` can directly give us $\log(q(\mathbf{y}))$:
B.logpdf_with_trans(LogNormal(), x, true)
```

-## Why is this useful for sampling anyway?
+## The need for bijectors in MCMC

Constraints pose a problem for pretty much any kind of numerical method, and sampling is no exception to this.
The problem is that for any value $x$ outside of the support of a constrained distribution, $p(x)$ will be zero, and the logpdf will be $-\infty$.
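To see this concretely, here is a quick illustrative check (not part of the commit) of what a constrained distribution returns outside its support:

```{julia}
using Distributions

# LogNormal is only supported on x > 0, so a negative value has
# zero density and an infinitely negative log-density.
pdf(LogNormal(), -1.0)     # 0.0
logpdf(LogNormal(), -1.0)  # -Inf
```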
@@ -432,7 +432,7 @@ This post is already really long, and does not have quite enough space to explai
If you need more information on these, please read e.g. chapter 11 of Bishop.
:::

-### Metropolis–Hastings... fine?
+### Metropolis–Hastings: fine?

This alone is not enough to cause issues for Metropolis–Hastings.
Here's an extremely barebones implementation of a random walk Metropolis algorithm:
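The post's own implementation is not shown in this hunk; as a rough sketch, a random-walk Metropolis sampler for a target density `p` might look something like the following (the names `mh`, `x_init`, `n_samples`, and `step` are placeholders, not the post's actual code):

```{julia}
# Illustrative sketch only; the post's real implementation may differ.
function mh(p, x_init, n_samples; step=0.5)
    samples = Vector{Float64}(undef, n_samples)
    x = x_init
    for i in 1:n_samples
        x_proposal = x + step * randn()    # symmetric Gaussian proposal
        if rand() < p(x_proposal) / p(x)   # accept with probability min(1, ratio)
            x = x_proposal
        end
        samples[i] = x
    end
    return samples
end
```

Called as, say, `mh(x -> pdf(LogNormal(), x), 1.0, 5000)`, this would produce something like the `samples_with_mh` that are histogrammed below.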
@@ -474,13 +474,14 @@ histogram(samples_with_mh, bins=0:0.1:5; xlims=(0, 5))
```

In this MH implementation, the only place where $p(x)$ comes into play is in the acceptance probability.
-Since we make sure to start the sampling at a point within the support of the distribution, `p(x)` will be nonzero.

-If the proposal step causes `x_proposal` to be outside the support, then `p(x_proposal)` will be zero, and the acceptance probability (`p(x_proposal)/p(x)`) will be zero.
+As long as we make sure to start the sampling at a point within the support of the distribution, `p(x)` will be nonzero.
+If the proposal step generates an `x_proposal` that is outside the support, `p(x_proposal)` will be zero, and the acceptance probability (`p(x_proposal)/p(x)`) will be zero.
So such a step will never be accepted, and the sampler will continue to stay within the support of the distribution.
+
Although this does mean that we may find ourselves having a higher reject rate than usual, and thus less efficient sampling, it at least does not cause the algorithm to become unstable or crash.
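To make the acceptance argument concrete, here is a small illustrative check (not from the original post), using the same log-normal target:

```{julia}
using Distributions

p(x) = pdf(LogNormal(), x)
x = 1.0               # current state, inside the support
x_proposal = -0.5     # proposed state, outside the support
p(x_proposal) / p(x)  # 0.0, so this proposal is always rejected
```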

-### Hamiltonian Monte Carlo... not so fine
+### Hamiltonian Monte Carlo: not so fine

The _real_ problem comes with gradient-based methods like Hamiltonian Monte Carlo (HMC).
Here's an equally barebones implementation of HMC.
@@ -582,48 +583,112 @@ z += timestep * r # (2)
r -= (timestep / 2) * dEdz(z) # (3)
```

-Here, `z` is the position.
+Here, `z` is the position and `r` the momentum.
Since we start our sampler inside the support of the distribution (by supplying a good initial point), `dEdz(z)` will start off being well-defined on line (1).
However, after `r` is updated on line (1), `z` is updated again on line (2), and _this_ value of `z` may well be outside of the support.
At this point, `dEdz(z)` will be `NaN`, and the final update to `r` on line (3) will also cause it to be `NaN`.

Even if we're lucky enough for an individual integration step to not move `z` outside the support, there are many integration steps per sampler step, and many sampler steps, and so the chances of this happening at some point are quite high.
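Here is an illustrative evaluation of that failure mode (not from the post itself), assuming the same energy definition the post uses for the constrained target, i.e. $E(z) = -\log p(z)$ with a log-normal $p$:

```{julia}
using Distributions, ForwardDiff

E(z) = -log(pdf(LogNormal(), z))        # energy of the constrained target
dEdz(z) = ForwardDiff.derivative(E, z)

E(1.0), dEdz(1.0)   # both finite: z is inside the support
E(-0.5)             # Inf, since p(-0.5) == 0
dEdz(-0.5)          # no longer well-defined (NaN), as described above
```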

-It's possible to choose your integration parameters carefully to reduce the risk of this happening.
+It's possible to choose our integration parameters carefully to reduce the risk of this happening.
For example, we could set the integration timestep to be _really_ small, thus reducing the chance of making a move outside the support.
But that will just lead to a very slow exploration of parameter space, and in general, we should like to avoid this problem altogether.

### Rescuing HMC

Perhaps unsurprisingly, the answer to this is to transform the underlying distribution to an unconstrained one and sample from that instead.
-However, to preserve the correct behaviour, we have to make sure that we include the pesky Jacobian term when sampling from the transformed distribution.
-Bijectors.jl can do all of this for us.
+However, we have to make sure that we include the pesky Jacobian term when sampling from the transformed distribution.
+That's where Bijectors.jl comes in.
+
+The main change we need to make is to pass a modified version of the function `p` to our HMC sampler.
+Recall that back at the very start, we transformed $p(x)$ into $q(y)$ and said that
+
+$$q(y) = p(x) \left| \frac{\mathrm{d}x}{\mathrm{d}y} \right|.$$
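For the running log-normal example this is easy to write out explicitly: with $y = \log(x)$ we have $x = e^y$ and $\mathrm{d}x/\mathrm{d}y = e^y$, so

$$q(y) = p(e^y)\, e^y,$$

which, when $p$ is the density of `LogNormal()`, works out to exactly the standard normal density.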

-The main thing we need to do is to pass a modified version of the function `p` to our HMC sampler.
-Recall the problem is that our `p` is zero outside the support of the distribution.
-What we can do is to instead specify `p` as the pdf of our transformed distribution, evaluated at the transformed value of `x` (which we'll call `y`).
+What we want the HMC sampler to see is the transformed distribution $q(y)$, not the original distribution $p(x)$.
+And Bijectors.jl lets us calculate $\log(q(y))$ using `logpdf_with_trans(p, x, true)`:

```{julia}
d = LogNormal()
-# Calling pdf() on a transformed distribution automatically includes
-# the Jacobian term
-p_transformed(y) = pdf(B.transformed(d), y)
-# These definitions are the same as before
-E(z) = -log(p_transformed(z))
+f = B.bijector(d) # Transformation function
+f_inv = B.inverse(f) # Inverse transformation function
+
+function logq(y)
+    x = f_inv(y)
+    return B.logpdf_with_trans(d, x, true)
+end
+# These definitions are the same as before, except that
+# the call to `log` now happens inside `logq` rather
+# than in `E`.
+E(z) = -logq(z)
dEdz(z) = ForwardDiff.derivative(E, z)
```
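As a quick sanity check (an illustrative aside, reusing the definitions just above), the transformed target can now be evaluated anywhere on the real line:

```{julia}
# y = -3 corresponds to x = exp(-3) ≈ 0.05, but any real y is now valid.
E(-3.0), dEdz(-3.0)   # both finite, unlike with the original constrained target
```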

-When we run HMC on this, it will give us back samples of `y`, not `x`.
-So we can untransform them, and voilà, our HMC sampler works again!
+The `exp`/`log` wrapping we would otherwise need (taking the `pdf` of the transformed distribution and then its `log`) is a bit awkward.
+In practice we only ever work on the log scale, which is exactly what `logq` does.
+
+Now, because our transformed distribution is unconstrained, we can evaluate `E` and `dEdz` at any point, and sample with more confidence:

```{julia}
samples_with_hmc = hmc(1.0, E, dEdz, 5000)
+samples_with_hmc[1:5]
+```
+
+No sampling errors this time... yay!
+We have to remember that when we run HMC on this, it will give us back samples of `y`, not `x`.
+So we can untransform them:

-bijector = B.bijector(d)
-samples_with_hmc_untransformed = B.inverse(bijector)(samples_with_hmc)
+```{julia}
+samples_with_hmc_untransformed = f_inv(samples_with_hmc)
histogram(samples_with_hmc_untransformed, bins=0:0.1:5; xlims=(0, 5))
```

+We can also check that the mean and variance of the samples are what we expect them to be.
+From [Wikipedia](https://en.wikipedia.org/wiki/Log-normal_distribution), the mean and variance of a log-normal distribution are respectively $\exp(\mu + \sigma^2/2)$ and $[\exp(\sigma^2) - 1]\exp(2\mu + \sigma^2)$.
+For our log-normal distribution, we set $\mu = 0$ and $\sigma = 1$, so the mean and variance should be $1.6487$ and $4.6707$ respectively.
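These reference values are quick to compute directly (a small illustrative check, not part of the original code):

```{julia}
μ, σ = 0.0, 1.0
# Closed-form mean and variance of LogNormal(μ, σ)
(exp(μ + σ^2 / 2), (exp(σ^2) - 1) * exp(2μ + σ^2))
```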
+
+```{julia}
+println("    mean : $(mean(samples_with_hmc_untransformed))")
+println("variance : $(var(samples_with_hmc_untransformed))")
+```
+
+::: {.callout-note}
+You might notice that the variance is a little bit off.
+The truth is that it's actually quite tricky to get an accurate variance when sampling from a log-normal distribution.
+You can see this even with Turing.jl itself:
+
+```{julia}
+using Turing
+setprogress!(false)
+@model ln() = x ~ LogNormal()
+chain = sample(ln(), HMC(0.2, 3), 5000)
+(mean(chain[:x]), var(chain[:x]))
+```
+:::
+
+The importance of the Jacobian term here isn't to enable sampling _per se_.
+Because the resulting distribution is unconstrained, we could have still sampled from it without using the Jacobian.
+However, adding the Jacobian is what ensures that when we un-transform the samples, we get the correct distribution.
+
+Here's what happens if we don't include the Jacobian term.
+In `logq_wrong` below, we un-transform `y` to `x` but calculate the logpdf with respect to the original, constrained distribution.
+This is exactly the same mistake that we made at the start of this article with `naive_logpdf`.
+
+```{julia}
+function logq_wrong(y)
+    x = f_inv(y)
+    return logpdf(d, x)
+end
+E(z) = -logq_wrong(z)
+dEdz(z) = ForwardDiff.derivative(E, z)
+samples_questionable = hmc(1.0, E, dEdz, 5000)
+samples_questionable_untransformed = f_inv(samples_questionable)
+
+println("    mean : $(mean(samples_questionable_untransformed))")
+println("variance : $(var(samples_questionable_untransformed))")
+```
+
+You can see that even though the sampling ran fine without errors, the summary statistics are completely wrong.

## How does DynamicPPL use bijectors?
