Say that $x$ is distributed according to `Normal()`, and we want to draw samples of $y = \exp(x)$.
Now, $y$ is itself a random variable, and like any other random variable, it will have a probability distribution, which we'll call $q(y)$.

In this specific case, the distribution of $y$ is known as a [log-normal distribution](https://en.wikipedia.org/wiki/Log-normal_distribution).
For illustration purposes, let's try to implement our own `MyLogNormal` distribution that we can sample from (see Distributions.jl's documentation on [custom distributions](https://juliastats.org/Distributions.jl/stable/extends/#Univariate-Distribution)).
(Distributions.jl already defines its own `LogNormal`, so we have to use a different name.)
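A minimal sketch of what this could look like (the exact code in this post may differ, and `naive_pdf` is just an illustrative name): we implement sampling via the Distributions.jl interface, together with a naive attempt at the density that simply reuses the normal pdf at $\log(y)$.

```{julia}
using Distributions, Random

struct MyLogNormal <: ContinuousUnivariateDistribution
    μ::Float64
    σ::Float64
end
MyLogNormal() = MyLogNormal(0.0, 1.0)

# Sampling: draw x ~ Normal(μ, σ) and return y = exp(x)
Base.rand(rng::AbstractRNG, d::MyLogNormal) = exp(rand(rng, Normal(d.μ, d.σ)))

# A naive (and, as we'll see, incorrect) density: just evaluate the normal
# pdf at x = log(y), ignoring how the transformation distorts densities
naive_pdf(d::MyLogNormal, y) = pdf(Normal(d.μ, d.σ), log(y))

rand(MyLogNormal())
```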
Fundamentally, the reason this doesn't work is that transforming a (continuous) distribution causes its probability density to be stretched and otherwise moved around.

::: {.callout-note}
There are various posts on the Internet that explain this visually.
:::

A perhaps more useful approach is to not talk about _probability densities_ themselves, but instead to make things more concrete by talking about actual _probabilities_.
If we think about the normal distribution as a continuous curve, what the probability density function $p(x)$ really tells us is that, for any two points $a$ and $b$ (where $a \leq b$), the probability of drawing a sample between $a$ and $b$ is the corresponding area under the curve, i.e.
This is the same as equation (11.9) in Bishop, except that he denotes the absolute value of the determinant with just $|\mathcal{J}|$.

::: {.callout-important}
Note that, if we have a function $f$ mapping $\mathbf{x}$ to $\mathbf{y}$, then the Jacobian matrix $\mathbf{J}$ (sometimes denoted $\mathbf{J}_f$) is usually defined _the other way round_, with the elements of $\mathbf{y}$ in the 'numerator' and the elements of $\mathbf{x}$ in the 'denominator' of the partial derivatives:

$$\mathbf{J}_{ij} = \frac{\partial y_i}{\partial x_j}.$$

Indeed, later in this article we will see that Bijectors.jl uses this convention.
This is why we have denoted this 'inverse' Jacobian as $\mathcal{J}$, rather than $\mathbf{J}$: $\mathcal{J}$ is really the Jacobian of the inverse function $f^{-1}$.
As it turns out, the matrix $\mathcal{J}$ is also the inverse of $\mathbf{J}$.
:::

The rest of this section will be devoted to an example to show that this works, and contains some slightly less pretty mathematics.
Technically, the bijections in Bijectors.jl are functions $f: X \to Y$ for which:

- $f$ is continuously differentiable, i.e. the derivative $\mathrm{d}f(x)/\mathrm{d}x$ exists and is continuous (over the domain of interest $X$);
- If $f^{-1}: Y \to X$ is the inverse of $f$, then it too is continuously differentiable (over _its_ own domain, i.e. $Y$).

The technical mathematical term for such a function is a diffeomorphism ([Wikipedia](https://en.wikipedia.org/wiki/Diffeomorphism)), but we call them 'bijectors'.

When thinking about continuous differentiability, it's important to be conscious of the domains or codomains that we care about.
For example, taking the inverse function $\log(y)$ from above, its derivative is $1/y$, which is not continuous at $y = 0$.
However, we specified that the bijection $y = \exp(x)$ maps values of $x \in (-\infty, \infty)$ to $y \in (0, \infty)$, so the point $y = 0$ is not within the domain of the inverse function.
:::

Specifically, one of the primary purposes of Bijectors.jl is to construct _bijections which map constrained distributions to unconstrained ones_.
For example, the log-normal distribution which we saw above is constrained: its _support_, i.e. the range over which $p(x) > 0$, is $(0, \infty)$.
However, we can transform that to an unconstrained distribution (the normal distribution) using the transformation $y = \log(x)$.

The `bijector` function, when applied to a distribution, returns a bijection $f$ that can be used to map the constrained distribution to an unconstrained one.
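For instance (a quick check, assuming Bijectors.jl has been imported as `B`, as in the code further down), the bijection returned for `LogNormal()` should behave like $\log$, mapping the support $(0, \infty)$ onto the whole real line:

```{julia}
f = B.bijector(LogNormal())
f(2.5), log(2.5)   # these two values should agree
```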
## Constrained vs unconstrained variables, sampling, etc.

Constraints pose a problem for pretty much any kind of numerical method, and sampling is no exception to this.
The problem is that for any value $x$ outside the support of a constrained distribution, $p(x)$ will be zero, and the logpdf will be $-\infty$.
Thus, any term that involves a ratio of probabilities (or, equivalently, a difference of logpdfs) can become infinite or undefined.
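For instance, with the log-normal distribution from above:

```{julia}
# Outside the support of LogNormal(), the pdf is zero and the logpdf is -Inf
pdf(LogNormal(), -1.0), logpdf(LogNormal(), -1.0)
```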
::: {.callout-note}
This post is already really long, and does not have quite enough space to explain either the Metropolis–Hastings or Hamiltonian Monte Carlo algorithms in detail.
If you need more information on these, please read e.g. chapter 11 of Bishop.
:::

### Metropolis–Hastings... fine?

This alone is not enough to cause issues for Metropolis–Hastings.
Here's an extremely barebones implementation of a random walk Metropolis algorithm:

```{julia}
# Take a step where the proposal is a normal distribution centred around
# the current value
function mh_step(p, x)
    x_proposed = rand(Normal(x, 1))
    acceptance_prob = min(1, p(x_proposed) / p(x))
    return if rand() < acceptance_prob
        x_proposed
    else
        x
    end
end

# Run a random walk Metropolis sampler.
# `p`  : a function that takes `x` and returns the pdf of the distribution
#        we're trying to sample from
# `x0` : the initial state
function mh(p, x0, n_samples)
    # Store the initial state as the first sample
    samples = [x0]
    x = x0
    for _ in 2:n_samples
        x = mh_step(p, x)
        push!(samples, x)
    end
    return samples
end
```

With this we can sample from a log-normal distribution just fine:
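For instance (the exact call and parameters in the post may differ; the numbers here are just illustrative):

```{julia}
using Statistics: mean, std

# Use the pdf of LogNormal() as the target, start within the support,
# and draw a reasonably large number of samples
p(x) = pdf(LogNormal(), x)
samples = mh(p, 1.0, 10_000)
mean(samples), std(samples)   # should be close to mean(LogNormal()), std(LogNormal())
```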
In this MH implementation, the only place where $p(x)$ comes into play is in the acceptance probability.
Since we make sure to start the sampling at a point within the support of the distribution, `p(x)` will be nonzero.

If the proposal step causes `x_proposed` to be outside the support, then `p(x_proposed)` will be zero, and the acceptance probability `p(x_proposed) / p(x)` will be zero.
So such a step will never be accepted, and the sampler will stay within the support of the distribution.
Although this does mean that we may end up with a higher rejection rate than usual, and thus less efficient sampling, it at least does not cause the algorithm to become unstable or crash.

### Hamiltonian Monte Carlo... not so fine
The _real_ problem comes with gradient-based methods like Hamiltonian Monte Carlo (HMC).
Here's an equally barebones implementation of HMC.

```{julia}
using LinearAlgebra: I
import ForwardDiff

# Really basic leapfrog integrator.
# `z`        : position
# `r`        : momentum
# `timestep` : size of one integration step
# `nsteps`   : number of integration steps
# `dEdz`     : function that returns the derivative of the energy with respect
#              to `z`. The energy is the negative logpdf of the distribution
#              we're trying to sample from.
function leapfrog(z, r, timestep, nsteps, dEdz)
    function step_inner(z, r)
        # One small step for r, one giant leap for z
        r -= (timestep / 2) * dEdz(z)
        z += timestep * r
        # (and then one more small step for r)
        r -= (timestep / 2) * dEdz(z)
        return (z, r)
    end
    for _ in 1:nsteps
        z, r = step_inner(z, r)
    end
    (isnan(z) || isnan(r)) && error("Numerical instability encountered in leapfrog")
    return (z, -r)
end

# Take one HMC step.
# `z` : current position
# `E` : function that returns the energy (negative logpdf) at `z`
# Other arguments are as above
function hmc_step(z, E, dEdz, integ_timestep, integ_nsteps)
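    # The body below is a minimal sketch of a standard HMC step (draw a random
    # momentum, simulate the dynamics with the leapfrog integrator, then do a
    # Metropolis-style accept/reject on the change in total energy); the exact
    # details may differ from the rest of this post's implementation.
    r = randn()
    z_new, r_new = leapfrog(z, r, integ_timestep, integ_nsteps, dEdz)
    H_old = E(z) + r^2 / 2
    H_new = E(z_new) + r_new^2 / 2
    acceptance_prob = min(1, exp(H_old - H_new))
    return rand() < acceptance_prob ? z_new : z
end
```

For the discussion below, assume that the energy and its gradient for our log-normal target are defined along these lines (a sketch; `p` is just the pdf of `LogNormal()`):

```{julia}
p(x) = pdf(LogNormal(), x)
E(z) = -log(p(z))
dEdz(z) = ForwardDiff.derivative(E, z)
```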
It turns out that evaluating the gradient of the energy at any point outside the support of the distribution is not possible:

```{julia}
dEdz(-1)
```

This is because $p(x)$ is 0, and hence $E(x) = -\log(p(x))$ is $\infty$, outside the support.
If we try to evaluate the gradient at such a point, it's simply undefined, because arithmetic on infinity doesn't make sense:

```{julia}
Inf - Inf
```

To really pinpoint where this is happening, we need to look into the HMC leapfrog integration, specifically these lines:

```julia
r -= (timestep / 2) * dEdz(z)   # (1)
z += timestep * r               # (2)
r -= (timestep / 2) * dEdz(z)   # (3)
```

Here, `z` is the position.
Since we start our sampler inside the support of the distribution (by supplying a good initial point), `dEdz(z)` will start off being well-defined on line (1).
However, after `r` is updated on line (1), `z` is then updated on line (2), and _this_ new value of `z` may well be outside of the support.
At this point, `dEdz(z)` will be `NaN`, and the final update to `r` on line (3) will also cause it to be `NaN`.

Even if we're lucky enough for an individual integration step to not move `z` outside the support, there are many integration steps per sampler step, and many sampler steps, so the chances of this happening at some point are quite high.

It's possible to choose your integration parameters carefully to reduce the risk of this happening.
For example, we could set the integration timestep to be _really_ small, thus reducing the chance of making a move outside the support.
But that will just lead to a very slow exploration of parameter space, and in general, we would like to avoid this problem altogether.

### Rescuing HMC
Perhaps unsurprisingly, the answer to this is to transform the underlying distribution to an unconstrained one and sample from that instead.
However, to preserve the correct behaviour, we have to make sure that we include the pesky Jacobian term when sampling from the transformed distribution.
Bijectors.jl can do all of this for us.

The main thing we need to do is to pass a modified version of the function `p` to our HMC sampler.
Recall that the problem is that our `p` is zero outside the support of the distribution.
What we can do is to instead specify `p` as the pdf of our transformed distribution, evaluated at the transformed value of `x` (which we'll call `y`).

```{julia}
d = LogNormal()
# Calling pdf() on a transformed distribution automatically includes
# the Jacobian term
p_transformed(y) = pdf(B.transformed(d), y)
# These definitions are the same as before
E(z) = -log(p_transformed(z))
dEdz(z) = ForwardDiff.derivative(E, z)
```
When we run HMC on this, it will give us back samples of `y`, not `x`.
So we can untransform them, and voilà, our HMC sampler works again!
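The untransformation step might look something like this (a sketch: `ys` stands for whatever vector of unconstrained samples the HMC run produced, and `d` is the `LogNormal()` from above):

```{julia}
# Map the unconstrained samples back to the constrained space.
# `ys` is assumed to hold the samples of y returned by the HMC sampler.
f = B.bijector(d)      # the map x -> y, i.e. (0, ∞) -> ℝ
f_inv = B.inverse(f)   # its inverse, ℝ -> (0, ∞); effectively exp here
xs = f_inv.(ys)
```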
In the final section of this article, we'll discuss the higher-level implications of constrained distributions in the Turing.jl framework.

When we are performing Bayesian inference, we're trying to sample from a joint probability distribution, which isn't usually a single, predefined distribution like in the rather simplified example above.
Instead, each random variable in the model will have its own distribution, and often some of these will be constrained.
For example, if `b ~ LogNormal()` is a random variable in a model, then $p(b)$ will be zero for any $b \leq 0$.
Consequently, any joint probability $p(b, c, \ldots)$ will also be zero for any combination of parameters where $b \leq 0$, and so the joint distribution is itself constrained.
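As a purely illustrative sketch (the model and names here are made up, not taken from elsewhere in this post), a Turing.jl model with one constrained and one unconstrained variable might look like this:

```{julia}
using Turing

# `b` is constrained to (0, ∞); `c` is unconstrained.
# The joint density p(b, c) is zero whenever b ≤ 0, so the joint
# distribution inherits the constraint from `b`.
@model function toy_model()
    b ~ LogNormal()
    c ~ Normal(b, 1)
end
```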
TODO: Talk about varinfo internals here I think.
It's all in `src/abstract_varinfo.jl`.
Unfortunately I probably need another few more days (at least) to understand this properly.

See [https://turinglang.org/DynamicPPL.jl/stable/internals/transformations/](https://turinglang.org/DynamicPPL.jl/stable/internals/transformations/)