Commit 0e790f3

Almost finish Bijectors post
1 parent 15e64aa commit 0e790f3

2 files changed: +248 -27 lines changed

Project.toml (+3)

@@ -2,6 +2,9 @@
 Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
+ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

src/transforms.qmd (+245 -27)

@@ -47,11 +47,12 @@ log(1 / sqrt(2π) * exp(-samples[1]^2 / 2))
 
 ## Sampling from a transformed distribution
 
-Now say that $x$ is distributed according to `Normal()`, and we want to draw samples from $y = \exp(x)$.
-The distribution of $y$ is known as a [log-normal distribution](https://en.wikipedia.org/wiki/Log-normal_distribution).
+Say that $x$ is distributed according to `Normal()`, and we want to draw samples of $y = \exp(x)$.
+Now, $y$ is itself a random variable, and like any other random variable, will have a probability distribution, which we'll call $q(y)$.
 
-For illustration purposes, let's make our own `MyLogNormal` distribution that we can sample from: see Distribution.jl's documentation on custom distributions [here](https://juliastats.org/Distributions.jl/stable/extends/#Univariate-Distribution).
-(Distributions already defines its own `LogNormal`, so we have to use a different name.)
+In this specific case, the distribution of $y$ is known as a [log-normal distribution](https://en.wikipedia.org/wiki/Log-normal_distribution).
+For illustration purposes, let's try to implement our own `MyLogNormal` distribution that we can sample from.
+(Distributions.jl already defines its own `LogNormal`, so we have to use a different name.)
 
 ```{julia}
 struct MyLogNormal <: ContinuousUnivariateDistribution
@@ -63,41 +64,44 @@ MyLogNormal() = MyLogNormal(0.0, 1.0)
 Base.rand(rng::Random.AbstractRNG, d::MyLogNormal) = exp(rand(rng, Normal(d.μ, d.σ)))
 ```
 
-Great, now we can do the same as above:
+Now we can do the same as above:
 
 ```{julia}
 samples_lognormal = rand(MyLogNormal(), 5000)
-# Cut off the tail for clearer visualization
+# Cut off the tail for clearer visualisation
 histogram(samples_lognormal, bins=0:0.1:5; xlims=(0, 5))
 ```
 
 How do we implement `logpdf` for our new distribution, though?
+Or in other words, if we observe a sample $y$, how do we know what the probability of drawing that sample was?
 
-Naively, we might think to just un-transform the variable `y`, and then use the `logpdf` of the normal distribution.
+Naively, we might think to just un-transform the variable `y` by reversing the exponential, i.e. taking the logarithm.
+We could then use the `logpdf` of the original distribution of `x`.
 
 ```{julia}
-bad_logpdf(d::MyLogNormal, y) = logpdf(Normal(d.μ, d.σ), log(y))
+naive_logpdf(d::MyLogNormal, y) = logpdf(Normal(d.μ, d.σ), log(y))
 ```
 
-We can compare this function against the logpdf implemented in Distributions.jl.
-(The name chosen here certainly foreshadows that it's not going to be correct, though!)
+We can compare this function against the logpdf implemented in Distributions.jl:
 
 ```{julia}
 println("Sample : $(samples_lognormal[1])")
 println("Expected : $(logpdf(LogNormal(), samples_lognormal[1]))")
-println("Actual : $(bad_logpdf(MyLogNormal(), samples_lognormal[1]))")
+println("Actual : $(naive_logpdf(MyLogNormal(), samples_lognormal[1]))")
 ```
 
+Clearly this approach is not quite correct!
+
 ## The derivative
 
 Fundamentally, the reason why this doesn't work is because transforming a (continuous) distribution causes probability density to be stretched and otherwise moved around.
 
 ::: {.callout-note}
-There are various posts on the Internet that explain this visually; I'm too lazy to draw a diagram _right now_, but I might do it later.
+There are various posts on the Internet that explain this visually.
 :::
 
-I personally find it most useful to not talk about probability density itself, but instead to make it more concrete by talking about actual probabilities.
-If we think about the normal distribution as a continuous curve, what the probability density function $p(x)$ really tells us is that for any two points $a$ and $b$ (where $a \leq b$), the probability of drawing a sample from the interval $[a, b]$ is the area under the curve, i.e.
+Perhaps a more useful approach is to not talk about _probability densities_, but instead to make it more concrete by talking about actual _probabilities_.
+If we think about the normal distribution as a continuous curve, what the probability density function $p(x)$ really tells us is that for any two points $a$ and $b$ (where $a \leq b$), the probability of drawing a sample between $a$ and $b$ is the corresponding area under the curve, i.e.
 
 $$\int_a^b p(x) \, \mathrm{d}x.$$
 
@@ -123,10 +127,10 @@ $$\int_{x=a}^{x=b} p(x) \, \mathrm{d}x
 
 from which we can read off $q(y) = p(\log(y)) / y$.
 
-In contrast, when we implemented `bad_logpdf`
+In contrast, when we implemented `naive_logpdf`
 
 ```{julia}
-bad_logpdf(d::MyLogNormal, y) = logpdf(Normal(d.μ, d.σ), log(y))
+naive_logpdf(d::MyLogNormal, y) = logpdf(Normal(d.μ, d.σ), log(y))
 ```
 
 that was the equivalent of saying that $q(y) = p(\log(y))$.
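For reference, a corrected `logpdf` that does include the $1/y$ factor might look like the following sketch (an illustration, not from this commit; the post itself may implement it differently):

```julia
# Hypothetical corrected version: q(y) = p(log(y)) / y, so on the log scale we
# subtract log(y) from the Normal logpdf.
# (A full implementation would also need to handle y <= 0.)
Distributions.logpdf(d::MyLogNormal, y::Real) =
    logpdf(Normal(d.μ, d.σ), log(y)) - log(y)

# This now agrees with Distributions.jl's own LogNormal
logpdf(MyLogNormal(), 2.5) ≈ logpdf(LogNormal(), 2.5)  # true
```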
@@ -177,12 +181,19 @@ $$q(y_1, y_2) = p(x_1, x_2) \left| \det(\mathcal{J}) \right|.$$
 
 This is the same as equation (11.9) in Bishop, except that he denotes the absolute value of the determinant with just $|\mathcal{J}|$.
 
-::: {.callout-note}
-In different contexts the Jacobian can have different 'numerators' and 'denominators' in the partial derivatives.
-For example, if $\mathbf{y} = f(\mathbf{x})$, then it's common to write $\mathbf{J}$ as a matrix of partial derivatives of elements of $y$ with respect to elements of $x$.
+::: {.callout-important}
+Note that, if we have a function $f$ mapping $\mathbf{x}$ to $\mathbf{y}$, then the Jacobian matrix $\mathbf{J}$ (sometimes denoted $\mathbf{J}_f$) is usually defined _the other way round_:
+
+$$\mathbf{J} = \begin{pmatrix}
+\partial y_1/\partial x_1 & \partial y_1/\partial x_2 \\
+\partial y_2/\partial x_1 & \partial y_2/\partial x_2
+\end{pmatrix}.$$
+
 Indeed, later in this article we will see that Bijectors.jl uses this convention.
+This is why we have denoted this 'inverse' Jacobian as $\mathcal{J}$, rather than $\mathbf{J}$.
 
-It is always the case, though, that the elements of the 'numerator' vary with rows and the elements of the 'denominator' vary with columns.
+$\mathcal{J}$ is really the Jacobian of the inverse function $f^{-1}$.
+As it turns out, the matrix $\mathcal{J}$ is also the inverse of $\mathbf{J}$.
 :::
 
 The rest of this section will be devoted to an example to show that this works, and contains some slightly less pretty mathematics.
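To make the last claim in that callout concrete, here is a small numerical check (an illustration with a made-up bijection, not from this commit): the Jacobian of $f^{-1}$, evaluated at $\mathbf{y} = f(\mathbf{x})$, is the matrix inverse of the Jacobian of $f$ at $\mathbf{x}$.

```julia
import ForwardDiff
using LinearAlgebra

# A made-up 2D bijection y = f(x), together with its inverse
f(x) = [exp(x[1]), exp(x[1] + x[2])]
finv(y) = [log(y[1]), log(y[2]) - log(y[1])]

x = [0.3, -1.2]
J = ForwardDiff.jacobian(f, x)           # Jacobian of f at x
Jinv = ForwardDiff.jacobian(finv, f(x))  # Jacobian of f⁻¹ at y = f(x)
Jinv ≈ inv(J)  # true
```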
@@ -301,19 +312,17 @@ Technically, the bijections in Bijectors.jl are functions $f: X \to Y$ for which
 - $f$ is continuously differentiable, i.e. the derivative $\mathrm{d}f(x)/\mathrm{d}x$ exists and is continuous (over the domain of interest $X$);
 - If $f^{-1}: Y \to X$ is the inverse of $f$, then that is also continuously differentiable (over _its_ own domain, i.e. $Y$).
 
-These are called diffeomorphisms ([Wikipedia](https://en.wikipedia.org/wiki/Diffeomorphism)).
+Functions like these are technically known as diffeomorphisms ([Wikipedia](https://en.wikipedia.org/wiki/Diffeomorphism)), but in this context we call them 'bijectors'.
 
 When thinking about continuous differentiability, it's important to be conscious of the domains or codomains that we care about.
 For example, taking the inverse function $\log(y)$ from above, its derivative is $1/y$, which is not continuous at $y = 0$.
 However, we specified that the bijection $y = \exp(x)$ maps values of $x \in (-\infty, \infty)$ to $y \in (0, \infty)$, so the point $y = 0$ is not within the domain of the inverse function.
 :::
 
-It's not entirely clear to me who first coined the term biject**or** (as opposed to biject**ion**), which is the mathematical term.
-As far as I can tell, it's only used in this specific context of transforming probability distributions, and apart from Bijectors.jl itself, it is also used in [the TensorFlow deep learning framework](https://www.tensorflow.org/probability/api_docs/python/tfp/bijectors).
-
 Specifically, one of the primary purposes of Bijectors.jl is to construct _bijections which map constrained distributions to unconstrained ones_.
-For example, the log-normal distribution which we saw above is constrained: its _support_, i.e. the range over which $p(x) \geq 0$, is (0, $\infty$).
+For example, the log-normal distribution which we saw above is constrained: its _support_, i.e. the range over which $p(x) > 0$, is $(0, \infty)$.
 However, we can transform that to an unconstrained distribution (the normal distribution) using the transformation $y = \log(x)$.
+
 The `bijector` function, when applied to a distribution, returns a bijection $f$ that can be used to map the constrained distribution to an unconstrained one.
 
 ```{julia}
@@ -414,10 +423,219 @@ B.logpdf_with_trans(LogNormal(), x, true)
 
 ## Why is this useful for sampling anyway?
 
-Constrained vs unconstrained variables, sampling, etc.
+Constraints pose a problem for pretty much any kind of numerical method, and sampling is no exception to this.
+The problem is that for any value $x$ outside of the support of a constrained distribution, $p(x)$ will be zero, and the logpdf will be $-\infty$.
+Thus, any term that involves the logpdf (or a ratio of probabilities) can end up being infinite, zero, or undefined.
+
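As a concrete illustration of this (an aside, not from this commit), Distributions.jl returns exactly these values outside the support:

```julia
using Distributions

pdf(LogNormal(), -1.0)     # 0.0
logpdf(LogNormal(), -1.0)  # -Inf
# A difference of two such logpdfs is undefined:
logpdf(LogNormal(), -1.0) - logpdf(LogNormal(), -2.0)  # NaN
```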
+::: {.callout-note}
+This post is already really long, and does not have quite enough space to explain either the Metropolis–Hastings or Hamiltonian Monte Carlo algorithms in detail.
+If you need more information on these, please read e.g. chapter 11 of Bishop.
+:::
+
+### Metropolis–Hastings... fine?
+
+This alone is not enough to cause issues for Metropolis–Hastings.
+Here's an extremely barebones implementation of a random walk Metropolis algorithm:
+
+```{julia}
+# Take a step where the proposal is a normal distribution centred around
+# the current value
+function mh_step(p, x)
+    x_proposed = rand(Normal(x, 1))
+    acceptance_prob = min(1, p(x_proposed) / p(x))
+    return if rand() < acceptance_prob
+        x_proposed
+    else
+        x
+    end
+end
+
+# Run a random walk Metropolis sampler.
+# `p`  : a function that takes `x` and returns the pdf of the distribution
+#        we're trying to sample from
+# `x0` : the initial state
+function mh(p, x0, n_samples)
+    samples = []
+    x = x0
+    for _ in 2:n_samples
+        x = mh_step(p, x)
+        push!(samples, x)
+    end
+    return samples
+end
+```
+
+With this we can sample from a log-normal distribution just fine:
+
+```{julia}
+p(x) = pdf(LogNormal(), x)
+samples_with_mh = mh(p, 1.0, 5000)
+histogram(samples_with_mh, bins=0:0.1:5; xlims=(0, 5))
+```
+
+In this MH implementation, the only place where $p(x)$ comes into play is in the acceptance probability.
+Since we make sure to start the sampling at a point within the support of the distribution, `p(x)` will be nonzero.
+
+If the proposal step causes `x_proposed` to be outside the support, then `p(x_proposed)` will be zero, and the acceptance probability (`p(x_proposed)/p(x)`) will be zero.
+So such a step will never be accepted, and the sampler will continue to stay within the support of the distribution.
+Although this does mean that we may end up with a higher rejection rate than usual, and thus less efficient sampling, it at least does not cause the algorithm to become unstable or crash.
+
+### Hamiltonian Monte Carlo... not so fine
+
+The _real_ problem comes with gradient-based methods like Hamiltonian Monte Carlo (HMC).
+Here's an equally barebones implementation of HMC.
+
+```{julia}
+using LinearAlgebra: I
+import ForwardDiff
+
+# Really basic leapfrog integrator.
+# `z`        : position
+# `r`        : momentum
+# `timestep` : size of one integration step
+# `nsteps`   : number of integration steps
+# `dEdz`     : function that returns the derivative of the energy with respect
+#              to `z`. The energy is the negative logpdf of the distribution
+#              we're trying to sample from.
+function leapfrog(z, r, timestep, nsteps, dEdz)
+    function step_inner(z, r)
+        # One small step for r, one giant leap for z
+        r -= (timestep / 2) * dEdz(z)
+        z += timestep * r
+        # (and then one more small step for r)
+        r -= (timestep / 2) * dEdz(z)
+        return (z, r)
+    end
+    for _ in 1:nsteps
+        z, r = step_inner(z, r)
+    end
+    (isnan(z) || isnan(r)) && error("Numerical instability encountered in leapfrog")
+    return (z, -r)
+end
+
+# Take one HMC step.
+# `z` : current position
+# `E` : function that returns the energy (negative logpdf) at `z`
+# Other arguments are as above
+function hmc_step(z, E, dEdz, integ_timestep, integ_nsteps)
+    # Generate new momentum
+    r = randn()
+    # Integrate the Hamiltonian dynamics
+    z_new, r_new = leapfrog(z, r, integ_timestep, integ_nsteps, dEdz)
+    # Calculate Hamiltonian
+    H = E(z) + 0.5 * sum(r .^ 2)
+    H_new = E(z_new) + 0.5 * sum(r_new .^ 2)
+    # Acceptance criterion
+    accept_prob = min(1, exp(H - H_new))
+    return if rand() < accept_prob
+        z_new
+    else
+        z
+    end
+end
+
+# Run HMC.
+# `z0` : initial position
+# Other arguments are as above
+function hmc(z0, E, dEdz, nsteps; integ_timestep=0.1, integ_nsteps=100)
+    samples = [z0]
+    z = z0
+    for _ in 2:nsteps
+        z = hmc_step(z, E, dEdz, integ_timestep, integ_nsteps)
+        push!(samples, z)
+    end
+    return samples
+end
+```
+
+Okay, that's our HMC set up.
+Now, let's try to sample from a log-normal distribution:
+
+```{julia}
+#| error: true
+p(x) = pdf(LogNormal(), x)
+E(x) = -log(p(x))
+dEdz(x) = ForwardDiff.derivative(E, x)
+samples_with_hmc = hmc(1.0, E, dEdz, 5000)
+histogram(samples_with_hmc, bins=0:0.1:5; xlims=(0, 5))
+```
+
+Eeeek! What happened?
+It turns out that evaluating the gradient of the energy at any point outside the support of the distribution is not possible:
+
+```{julia}
+dEdz(-1)
+```
+
+This is because $p(x)$ is 0, and hence $E(x) = -\log(p(x))$ is $\infty$ outside the support.
+If we try to evaluate the gradient at such a point, it's simply undefined, because arithmetic on infinity doesn't make sense:
+
+```{julia}
+Inf - Inf
+```
+
+To really pinpoint where this is happening, we need to look into the HMC leapfrog integration, specifically these lines:
+
+```julia
+r -= (timestep / 2) * dEdz(z)  # (1)
+z += timestep * r              # (2)
+r -= (timestep / 2) * dEdz(z)  # (3)
+```
+
+Here, `z` is the position.
+Since we start our sampler inside the support of the distribution (by supplying a good initial point), `dEdz(z)` will start off being well-defined on line (1).
+However, after `r` is updated on line (1), `z` is then updated on line (2), and _this_ value of `z` may well be outside of the support.
+At this point, `dEdz(z)` will be `NaN`, and the final update to `r` on line (3) will also cause it to be `NaN`.
+
+Even if we're lucky enough for an individual integration step to not move `z` outside the support, there are many integration steps per sampler step, and many sampler steps, and so the chances of this happening at some point are quite high.
+
+It's possible to choose your integration parameters carefully to reduce the risk of this happening.
+For example, we could set the integration timestep to be _really_ small, thus reducing the chance of making a move outside the support.
+But that will just lead to a very slow exploration of parameter space, and in general, we would like to avoid this problem altogether.
+
+### Rescuing HMC
+
+Perhaps unsurprisingly, the answer to this is to transform the underlying distribution to an unconstrained one and sample from that instead.
+However, to preserve the correct behaviour, we have to make sure that we include the pesky Jacobian term when sampling from the transformed distribution.
+Bijectors.jl can do all of this for us.
+
+The main thing we need to do is to pass a modified version of the function `p` to our HMC sampler.
+Recall that the problem is that our `p` is zero outside the support of the distribution.
+What we can do is to instead specify `p` as the pdf of our transformed distribution, evaluated at the transformed value of `x` (which we'll call `y`).
+
+```{julia}
+d = LogNormal()
+# Calling pdf() on a transformed distribution automatically includes
+# the Jacobian term
+p_transformed(y) = pdf(B.transformed(d), y)
+# These definitions are the same as before
+E(z) = -log(p_transformed(z))
+dEdz(z) = ForwardDiff.derivative(E, z)
+```
+
+When we run HMC on this, it will give us back samples of `y`, not `x`.
+So we can untransform them, and voilà, our HMC sampler works again!
+
+```{julia}
+samples_with_hmc = hmc(1.0, E, dEdz, 5000)
+
+bijector = B.bijector(d)
+samples_with_hmc_untransformed = B.inverse(bijector)(samples_with_hmc)
+histogram(samples_with_hmc_untransformed, bins=0:0.1:5; xlims=(0, 5))
+```
+
 
 ## How does DynamicPPL use bijectors?
 
-link, invlink, transform, varinfo etc.
+In the final section of this article, we'll discuss the higher-level implications of constrained distributions in the Turing.jl framework.
+
+When we are performing Bayesian inference, we're trying to sample from a joint probability distribution, which isn't usually a single, well-defined distribution like in the rather simplified example above.
+However, each random variable in the model will have its own distribution, and often some of these will be constrained.
+For example, if `b ~ LogNormal()` is a random variable in a model, then $p(b)$ will be zero for any $b \leq 0$.
+Consequently, any joint probability $p(b, c, \ldots)$ will also be zero for any combination of parameters where $b \leq 0$, and so that joint distribution is itself constrained.
+
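To make this concrete (an illustrative sketch, not from this commit; the model and names are made up), a Turing.jl model with one constrained and one unconstrained variable might look like the following; gradient-based samplers such as NUTS then work in the unconstrained space that Bijectors.jl maps to:

```julia
using Turing

# Hypothetical model: `b` is constrained (support (0, ∞)), `c` is not
@model function demo_model(y)
    b ~ LogNormal()
    c ~ Normal(0, 1)
    y ~ Normal(c, b)   # `b` is used as a standard deviation, so it must stay positive
end

chain = sample(demo_model(1.5), NUTS(), 1000)
```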
+TODO: Talk about varinfo internals here I think.
+It's all in `src/abstract_varinfo.jl`.
+Unfortunately I probably need another few more days (at least) to understand this properly.
 
 See [https://turinglang.org/DynamicPPL.jl/stable/internals/transformations/](https://turinglang.org/DynamicPPL.jl/stable/internals/transformations/)
