src/transforms.qmd
+140-4
@@ -699,8 +699,144 @@ However, each random variable in the model will have its own distribution, and o
699
699
For example, if `b ~ LogNormal()` is a random variable in a model, then $p(b)$ will be zero for any $b \leq 0$.
700
700
Consequently, any joint probability $p(b, c, \ldots)$ will also be zero for any combination of parameters where $b \leq 0$, and so that joint distribution is itself constrained.
701
701
702
-
TODO: Talk about varinfo internals here I think.
703
-
It's all in `src/abstract_varinfo.jl`.
704
-
Unfortunately I probably need another few more days (at least) to understand this properly.
702
+
To get around this, DynamicPPL allows the variables to be transformed in exactly the same way as above.
703
+
For simplicity, consider the following model:
705
704
706
-
See [https://turinglang.org/DynamicPPL.jl/stable/internals/transformations/](https://turinglang.org/DynamicPPL.jl/stable/internals/transformations/)
705
+
```{julia}
706
+
using DynamicPPL
707
+
708
+
@model function demo()
709
+
x ~ LogNormal()
710
+
end
711
+
712
+
model = demo()
713
+
vi = VarInfo(model)
714
+
vn_x = @varname(x)
715
+
# Retrieve the 'internal' value of x – we'll explain this later
716
+
DynamicPPL.getindex_internal(vi, vn_x)
717
+
```
718
+
719
+
The call to `VarInfo` executes the model once and stores the sampled value inside `vi`.
720
+
By default, `VarInfo` itself stores un-transformed values.
721
+
We can see this by comparing the value of the logpdf stored inside the `VarInfo`:
In DynamicPPL, the `link!!` function can be used to transform the variables.
734
+
These functions do three things: firstly, they transform the variables; secondly, they update the value of logp (by adding the Jacobian term); and thirdly, they set a flag on each variable to indicate that it has been transformed.
735
+
Note that these functions act on _all_ variables in the model, including unconstrained ones.
736
+
(Unconstrained variables just have an identity transformation.)
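
To make the Jacobian term concrete, here is a minimal sketch in plain Julia that does not call DynamicPPL at all. The functions `lognormal_logpdf` and `normal_logpdf` are hand-written stand-ins introduced only for this illustration, and the value `2.5` is arbitrary. It assumes the transform is `log`, as for `LogNormal()`, so the inverse transform is `exp` and its log absolute Jacobian at `y` is `y` itself.

```julia
# Hand-written log-density of LogNormal(0, 1); -Inf outside the support.
lognormal_logpdf(x) = x > 0 ? -log(x) - 0.5 * log(2π) - 0.5 * log(x)^2 : -Inf

# Log-density of Normal(0, 1), used only as a check.
normal_logpdf(y) = -0.5 * log(2π) - 0.5 * y^2

x = 2.5        # constrained ('unlinked') value
y = log(x)     # unconstrained ('linked') value

# Log absolute Jacobian of the inverse transform exp, evaluated at y.
log_jacobian = y

logp_unlinked = lognormal_logpdf(x)
logp_linked = logp_unlinked + log_jacobian

# If x ~ LogNormal(0, 1) then log(x) ~ Normal(0, 1), so these should agree.
logp_linked ≈ normal_logpdf(y)
```

The final comparison is a sanity check: adding the Jacobian term to the log-density in the constrained space gives the log-density of the transformed variable.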
We can see that there are _two_ differences between these outputs:
775
+
776
+
1. _The internal value has been transformed using the bijector (in this case, the log function)._
777
+
This means that the `istrans()` flag which we used above doesn't tell us anything about whether the 'external' value has been transformed: it only tells us about the internal value.
778
+
779
+
2. _The internal value is a vector, whereas the value is a scalar._
780
+
This is because _all_ internal values are vectorised (i.e. converted into some vector), regardless of distribution.
Essentially, the value is the one which the user 'expects' to see based on the model definition.
789
+
The 'internal' value is the representation that is most convenient to work with inside DynamicPPL.
790
+
791
+
It also means that internally, the transformation in `link!!` is carried out in three steps:
792
+
793
+
1. Un-vectorise the internal value.
794
+
2. Apply the transformation.
795
+
3. Vectorise the transformed value.
796
+
797
+
The actual implementation is slightly harder to parse as it has to work for different flavours of `VarInfo`, but it eventually boils down to the following:
798
+
799
+
```{julia}
800
+
invlink!!(vi, model) # Reset to un-transformed state
801
+
dist = DynamicPPL.getdist(vi, vn_x)
802
+
x_val = DynamicPPL.getindex_internal(vi, vn_x)
803
+
```
804
+
805
+
```{julia}
806
+
# Step 1: un-vectorise
807
+
fn1 = DynamicPPL.from_vec_transform(dist)
808
+
fn1(x_val)
809
+
```
810
+
811
+
```{julia}
812
+
# Step 2: transform
813
+
# DynamicPPL.link_transform(dist) is really Bijectors.bijector(dist)
814
+
fn2 = DynamicPPL.link_transform(dist)
815
+
fn2(fn1(x_val))
816
+
```
817
+
818
+
```{julia}
819
+
# Step 3: re-vectorise
820
+
fn3 = DynamicPPL.to_vec_transform(dist)
821
+
fn3(fn2(fn1(x_val)))
822
+
```
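
The same three-step pattern can be run in reverse to undo the transformation. The following is a sketch in plain Julia, not DynamicPPL code: `unvectorise`, `vectorise`, `link_internal` and `invlink_internal` are hypothetical names chosen for this illustration, and it assumes a single positive scalar variable whose transform is `log`.

```julia
# Hypothetical stand-ins for a univariate positive variable (not DynamicPPL API).
unvectorise(v) = only(v)   # length-1 vector -> scalar
vectorise(x) = [x]         # scalar -> length-1 vector

# Forward direction: constrained internal value -> unconstrained internal value.
link_internal(v) = vectorise(log(unvectorise(v)))

# Reverse direction: the same three steps, using the inverse transform (exp).
invlink_internal(v) = vectorise(exp(unvectorise(v)))

internal = [2.5]
linked = link_internal(internal)
restored = invlink_internal(linked)

# The round trip should recover the original internal value.
restored ≈ internal
```

Both directions keep the internal value as a vector, so only the middle step changes between linking and inverse linking.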
823
+
824
+
### So when does the transformation actually happen?
825
+
826
+
TODO
827
+
828
+
... see HMC implementation in Turing
829
+
830
+
... logdensity evaluation on a LogDensityFunction -> `evaluate!!`
831
+
832
+
... tilde pipeline
833
+
834
+
e.g. `assume` for HMC is here https://github.com/TuringLang/Turing.jl/blob/5b24cebe773922e0f3d5c4cb7f53162eb758b04d/src/mcmc/hmc.jl#L493C1-L498C4
835
+
836
+
which goes to the default `assume` implementation https://github.com/TuringLang/DynamicPPL.jl/blob/ba490bf362653e1aaefe298364fe3379b60660d3/src/context_implementations.jl#L225-L229
837
+
838
+
which leads to `invlink_with_logpdf` https://github.com/TuringLang/DynamicPPL.jl/blob/ba490bf362653e1aaefe298364fe3379b60660d3/src/abstract_varinfo.jl#L773-L792
839
+
840
+
which returns the UNTRANSFORMED value (important to explain why here – it's because the return value is assigned to the variable in the model, which the user can see) and the appropriately calculated logpdf, depending on whether `istrans(vi, vn)` returns true
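
As a rough illustration of the behaviour described in the note above, the branching might look something like the following plain-Julia sketch. This is not the DynamicPPL implementation: `value_and_logp` and `lognormal_logpdf` are hypothetical helpers, the `linked` argument merely plays the role of the `istrans(vi, vn)` flag, and the sketch assumes a `LogNormal()` variable whose linked value is `log(x)`. It also assumes that in the linked case the log-Jacobian of the inverse transform is added to the log-density.

```julia
# Hand-written log-density of LogNormal(0, 1), standing in for logpdf(dist, x).
lognormal_logpdf(x) = x > 0 ? -log(x) - 0.5 * log(2π) - 0.5 * log(x)^2 : -Inf

# Hypothetical helper: `internal` is the stored scalar, `linked` says whether it is transformed.
function value_and_logp(internal::Real, linked::Bool)
    if linked
        x = exp(internal)       # undo the log transform
        logjac = internal       # log absolute Jacobian of exp at `internal`
    else
        x = internal            # already in the constrained space
        logjac = zero(internal)
    end
    # In both cases the returned value is the untransformed one.
    return x, lognormal_logpdf(x) + logjac
end

x1, logp1 = value_and_logp(2.5, false)
x2, logp2 = value_and_logp(log(2.5), true)
```

In both calls the first return value is the untransformed `x`, which is what the model body would see; only the log-density differs, by the Jacobian term.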