Commit 6025893

Almost done with transforms _sigh_
1 parent 7c7a6f6 commit 6025893

1 file changed, +140 -4 lines changed

src/transforms.qmd

@@ -699,8 +699,144 @@ However, each random variable in the model will have its own distribution, and o
699699
For example, if `b ~ LogNormal()` is a random variable in a model, then $p(b)$ will be zero for any $b \leq 0$.
700700
Consequently, any joint probability $p(b, c, \ldots)$ will also be zero for any combination of parameters where $b \leq 0$, and so that joint distribution is itself constrained.
701701

702-
TODO: Talk about varinfo internals here I think.
703-
It's all in `src/abstract_varinfo.jl`.
704-
Unfortunately I probably need another few more days (at least) to understand this properly.
702+
To get around this, DynamicPPL allows the variables to be transformed in exactly the same way as above.
703+
For simplicity, consider the following model:
705704

706-
See [https://turinglang.org/DynamicPPL.jl/stable/internals/transformations/](https://turinglang.org/DynamicPPL.jl/stable/internals/transformations/)
705+
```{julia}
706+
using DynamicPPL
707+
708+
@model function demo()
709+
x ~ LogNormal()
710+
end
711+
712+
model = demo()
713+
vi = VarInfo(model)
714+
vn_x = @varname(x)
715+
# Retrieve the 'internal' value of x – we'll explain this later
716+
DynamicPPL.getindex_internal(vi, vn_x)
717+
```
718+
719+
The call to `VarInfo` executes the model once and stores the sampled value inside `vi`.
720+
By default, `VarInfo` itself stores un-transformed values.
721+
We can see this by comparing the value of the logpdf stored inside the `VarInfo`:
722+
723+
```{julia}
724+
DynamicPPL.getlogp(vi)
725+
```
726+
727+
with a manual calculation:
728+
729+
```{julia}
730+
logpdf(LogNormal(), DynamicPPL.getindex_internal(vi, vn_x))
731+
```
732+
733+
In DynamicPPL, the `link!!` function can be used to transform the variables.
734+
This function does three things: firstly, it transforms the variables; secondly, it updates the value of logp (by adding the Jacobian term); and thirdly, it sets a flag on each variable to indicate that it has been transformed.
735+
Note that `link!!` acts on _all_ variables in the model, including unconstrained ones.
736+
(Unconstrained variables just have an identity transformation.)
737+
738+
```{julia}
739+
DynamicPPL.link!!(vi, model)
740+
println("Transformed value: $(DynamicPPL.getindex_internal(vi, vn_x))")
741+
println("Transformed logp: $(DynamicPPL.getlogp(vi))")
742+
println("Transformed flag: $(DynamicPPL.istrans(vi, vn_x))")
743+
```
744+
745+
Indeed, we can see that the new logp value matches the log-density of a standard normal distribution evaluated at the transformed value:
746+
747+
```{julia}
748+
logpdf(Normal(), DynamicPPL.getindex_internal(vi, vn_x))
749+
```
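
The agreement follows from the change-of-variables formula. As a short derivation (the symbols $x$, $y$, $p_X$ and $p_Y$ are notation introduced here for illustration), write $y = \log x$ for the transformed variable, so that $x = e^y$ and $\mathrm{d}x/\mathrm{d}y = e^y$:

$$
\begin{aligned}
\log p_Y(y) &= \log p_X(e^y) + \log\left|\frac{\mathrm{d}x}{\mathrm{d}y}\right| \\
&= \left[-y - \tfrac{1}{2}\log(2\pi) - \tfrac{1}{2}y^2\right] + y \\
&= -\tfrac{1}{2}\log(2\pi) - \tfrac{1}{2}y^2,
\end{aligned}
$$

where the bracketed term is the log-density of `LogNormal()`, $\log p_X(x) = -\log x - \tfrac{1}{2}\log(2\pi) - \tfrac{1}{2}(\log x)^2$, evaluated at $x = e^y$. The final line is exactly the log-density of a standard normal at $y$. The $+y$ term is the Jacobian term.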
750+
751+
The reverse transformation, `invlink!!`, reverts all of the above steps:
752+
753+
```{julia}
754+
DynamicPPL.invlink!!(vi, model)
755+
println("Un-transformed value: $(DynamicPPL.getindex_internal(vi, vn_x))")
756+
println("Un-transformed logp: $(DynamicPPL.getlogp(vi))")
757+
println("Un-transformed flag: $(DynamicPPL.istrans(vi, vn_x))")
758+
```
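
The same round trip can be checked outside of a `VarInfo`. The following is a minimal sketch, assuming that Bijectors.jl and Distributions.jl are available and that `bijector(LogNormal())` behaves like the log function (as stated above). The fixed value `x = 0.7` is arbitrary and chosen only for illustration.

```julia
using Distributions
using Bijectors

dist = LogNormal()
b = bijector(dist)      # forward transform: constrained -> unconstrained
binv = inverse(b)       # reverse transform: unconstrained -> constrained

x = 0.7                 # arbitrary positive value, for illustration only
y = b(x)                # transformed value (log(x) for this distribution)

# Log-density in unconstrained space: subtract the log-abs-det-Jacobian
# of the forward transform (equivalently, add that of the inverse).
lp_linked = logpdf(dist, x) - logabsdetjac(b, x)

# Going back recovers the original value and the original log-density.
x_back = binv(y)
lp_back = lp_linked - logabsdetjac(binv, y)
```

Here `lp_linked` plays the role of the logp after `link!!`, and `lp_back` the role of the logp after `invlink!!`.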
759+
760+
### Values and 'internal' values
761+
762+
In DynamicPPL, there is a difference between the value of a random variable and its 'internal' value.
763+
This is most easily seen by first transforming, and then comparing the output of `getindex` and `getindex_internal`.
764+
The former extracts the regular value, whereas (as the name suggests) the latter gets the 'internal' value.
765+
766+
```{julia}
767+
# Transform
768+
DynamicPPL.link!!(vi, model)
769+
770+
println("Value: $(getindex(vi, vn_x))") # same as `vi[vn_x]`
771+
println("Internal value: $(DynamicPPL.getindex_internal(vi, vn_x))")
772+
```
773+
774+
We can see that there are _two_ differences between these outputs:
775+
776+
1. _The internal value has been transformed using the bijector (in this case, the log function)._
777+
This means that the `istrans()` flag which we used above doesn't tell us anything about whether the 'external' value has been transformed: it only tells us about the internal value.
778+
779+
2. _The internal value is a vector, whereas the value is a scalar._
780+
This is because _all_ internal values are vectorised (i.e. converted into some vector), regardless of distribution.
781+
782+
| Distribution | Value | Internal value |
783+
| --- | --- | --- |
784+
| Univariate (e.g. `Normal()`) | Scalar | Length-1 vector, possibly transformed |
785+
| Multivariate (e.g. `MvNormal()`) | Vector | Vector, possibly transformed |
786+
| Matrixvariate (e.g. `Wishart()`) | Matrix | Vectorised matrix, possibly transformed |
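
To make the table concrete, here is a rough sketch of the vectorisation idea using only plain Julia functions. These are stand-ins chosen for illustration and are not the functions DynamicPPL uses internally. The exact internal layout, especially for transformed matrix-variate values, may differ.

```julia
# Univariate: a scalar value corresponds to a length-1 vector.
x_scalar = 0.5
x_internal = [x_scalar]
x_recovered = only(x_internal)      # un-vectorise

# Multivariate: the value is already a vector.
v = [0.1, -0.3]
v_internal = copy(v)

# Matrix-variate: flatten the matrix (column-major) into a vector.
m = [2.0 0.5; 0.5 1.0]
m_internal = vec(m)
m_recovered = reshape(m_internal, size(m))   # un-vectorise
```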
787+
788+
Essentially, the value is the one which the user 'expects' to see based on the model definition.
789+
The 'internal' value is the representation that is most convenient to work with inside DynamicPPL.
790+
791+
It also means that internally, the transformation in `link!!` is carried out in three steps:
792+
793+
1. Un-vectorise the internal value.
794+
2. Apply the transformation.
795+
3. Vectorise the transformed value.
796+
797+
The actual implementation is slightly harder to parse as it has to work for different flavours of `VarInfo`, but it eventually boils down to the following:
798+
799+
```{julia}
800+
DynamicPPL.invlink!!(vi, model) # Reset to un-transformed state
801+
dist = DynamicPPL.getdist(vi, vn_x)
802+
x_val = DynamicPPL.getindex_internal(vi, vn_x)
803+
```
804+
805+
```{julia}
806+
# Step 1: un-vectorise
807+
fn1 = DynamicPPL.from_vec_transform(dist)
808+
fn1(x_val)
809+
```
810+
811+
```{julia}
812+
# Step 2: transform
813+
# DynamicPPL.link_transform(dist) is really Bijectors.bijector(dist)
814+
fn2 = DynamicPPL.link_transform(dist)
815+
fn2(fn1(x_val))
816+
```
817+
818+
```{julia}
819+
# Step 3: re-vectorise
820+
fn3 = DynamicPPL.to_vec_transform(dist)
821+
fn3(fn2(fn1(x_val)))
822+
```
823+
824+
### So when does the transformation actually happen?
825+
826+
TODO
827+
828+
... see HMC implementation in Turing
829+
830+
... logdensity evaluation on a LogDensityFunction -> `evaluate!!`
831+
832+
... tilde pipeline
833+
834+
e.g. `assume` for HMC is here https://github.com/TuringLang/Turing.jl/blob/5b24cebe773922e0f3d5c4cb7f53162eb758b04d/src/mcmc/hmc.jl#L493C1-L498C4
835+
836+
which goes to the default `assume` implementation https://github.com/TuringLang/DynamicPPL.jl/blob/ba490bf362653e1aaefe298364fe3379b60660d3/src/context_implementations.jl#L225-L229
837+
838+
which leads to `invlink_with_logpdf` https://github.com/TuringLang/DynamicPPL.jl/blob/ba490bf362653e1aaefe298364fe3379b60660d3/src/abstract_varinfo.jl#L773-L792
839+
840+
which returns the UNTRANSFORMED value (important to explain why here – it's because the return value is assigned to the variable in the model, which the user can see) and the appropriately calculated logpdf, depending on whether `istrans(vi, vn)` returns true
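
As a rough sketch of the behaviour described in these notes (not the DynamicPPL implementation), the following hypothetical helper `sketch_invlink_with_logpdf` handles only a `LogNormal` variable. It assumes that the internal value is a length-1 vector and that the Jacobian term is added only when the transformed flag is set. Either way, the first return value is the untransformed value.

```julia
using Distributions

# Hypothetical helper, for illustration only.
function sketch_invlink_with_logpdf(
    dist::LogNormal, internal::AbstractVector, is_transformed::Bool
)
    v = only(internal)          # un-vectorise the internal value
    if is_transformed
        x = exp(v)              # invert the log transform
        logjac = v              # log|dx/dy| for x = exp(y)
    else
        x = v                   # already in the original space
        logjac = zero(v)
    end
    # Always return the untransformed value, plus the log-density
    # appropriate to the space the internal value lives in.
    return x, logpdf(dist, x) + logjac
end

dist = LogNormal()
x1, lp1 = sketch_invlink_with_logpdf(dist, [2.0], false)
x2, lp2 = sketch_invlink_with_logpdf(dist, [log(2.0)], true)
```

Both calls return (approximately) the same value for the variable, but different log-densities: `lp1` is the density in the original space, and `lp2` is the density in the transformed space.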
841+
842+
TODO: Understand `maybe_invlink_before_eval`.
