Commit 02f0ba0

committed
I think I'm done
1 parent 6025893 commit 02f0ba0

1 file changed

+66
-32
lines changed

src/transforms.qmd

@@ -624,9 +624,6 @@ E(z) = -logq(z)
624624
dEdz(z) = ForwardDiff.derivative(E, z)
625625
```
626626

627-
The `exp`/`log` wrapping is a bit awkward.
628-
In practice we would only ever work on the log scale, but
629-
630627
Now, because our transformed distribution is unconstrained, we can evaluate `E` and `dEdz` at any point, and sample with more confidence:
631628

632629
```{julia}
@@ -730,28 +727,28 @@ with a manual calculation:
730727
logpdf(LogNormal(), DynamicPPL.getindex_internal(vi, vn_x))
731728
```
732729

733-
In DynamicPPL, the `link!!` function can be used to transform the variables.
734-
These functions do three things: firstly, they transform the variables; secondly, they update the value of logp (by adding the Jacobian term); and thirdly, they set a flag on the variables to indicate that it has been transformed.
735-
Note that these functions act on _all_ variables in the model, including unconstrained ones.
730+
In DynamicPPL, the `link` function can be used to transform the variables.
731+
This function does three things: firstly, it transforms the variables; secondly, it updates the value of logp (by adding the Jacobian term); and thirdly, it sets a flag on the variables to indicate that they have been transformed.
732+
Note that this acts on _all_ variables in the model, including unconstrained ones.
736733
(Unconstrained variables just have an identity transformation.)
737734

738735
```{julia}
739-
DynamicPPL.link!!(vi, model)
740-
println("Transformed value: $(DynamicPPL.getindex_internal(vi, vn_x))")
741-
println("Transformed logp: $(DynamicPPL.getlogp(vi))")
742-
println("Transformed flag: $(DynamicPPL.istrans(vi, vn_x))")
736+
vi_linked = DynamicPPL.link(vi, model)
737+
println("Transformed value: $(DynamicPPL.getindex_internal(vi_linked, vn_x))")
738+
println("Transformed logp: $(DynamicPPL.getlogp(vi_linked))")
739+
println("Transformed flag: $(DynamicPPL.istrans(vi_linked, vn_x))")
743740
```
744741

745742
Indeed, we can see that the new logp value matches with
746743

747744
```{julia}
748-
logpdf(Normal(), DynamicPPL.getindex_internal(vi, vn_x))
745+
logpdf(Normal(), DynamicPPL.getindex_internal(vi_linked, vn_x))
749746
```
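
To see why the check above works, the change of variables can be written out by hand. This is a short derivation rather than anything taken from the DynamicPPL source; the symbols $p$ (density of the constrained variable $x$) and $q$ (density of the transformed variable $y = \log x$) are assumed to follow the notation used earlier in this document.

$$
\begin{aligned}
\log q(y) &= \log p(e^y) + \log\left|\frac{\mathrm{d}x}{\mathrm{d}y}\right| \\
&= \left(-y - \tfrac{1}{2}\log 2\pi - \tfrac{1}{2}y^2\right) + y \\
&= -\tfrac{1}{2}\log 2\pi - \tfrac{1}{2}y^2
\end{aligned}
$$

The final line is the log density of a standard normal distribution evaluated at $y$, and the $+y$ term in the second line is the Jacobian term that is added to logp.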
750747

751-
The reverse transformation, `invlink!!`, reverts all of the above steps:
748+
The reverse transformation, `invlink`, reverts all of the above steps:
752749

753750
```{julia}
754-
DynamicPPL.invlink!!(vi, model)
751+
vi = DynamicPPL.invlink(vi_linked, model) # Same as the previous vi
755752
println("Un-transformed value: $(DynamicPPL.getindex_internal(vi, vn_x))")
756753
println("Un-transformed logp: $(DynamicPPL.getlogp(vi))")
757754
println("Un-transformed flag: $(DynamicPPL.istrans(vi, vn_x))")
@@ -764,14 +761,11 @@ This is most easily seen by first transforming, and then comparing the output of
764761
The former extracts the regular value, whereas (as the name suggests) the latter gets the 'internal' value.
765762

766763
```{julia}
767-
# Transform
768-
DynamicPPL.link!!(vi, model)
769-
770-
println("Value: $(getindex(vi, vn_x))") # same as `vi[vn_x]`
771-
println("Internal value: $(DynamicPPL.getindex_internal(vi, vn_x))")
764+
println("Value: $(getindex(vi_linked, vn_x))") # same as `vi_linked[vn_x]`
765+
println("Internal value: $(DynamicPPL.getindex_internal(vi_linked, vn_x))")
772766
```
773767

774-
We can see that there are _two_ differences between these outputs:
768+
We can see (for the linked varinfo) that there are _two_ differences between these outputs:
775769

776770
1. _The internal value has been transformed using the bijector (in this case, the log function)._
777771
This means that the `istrans()` flag which we used above doesn't tell us anything about whether the 'external' value has been transformed: it only tells us about the internal value.
@@ -788,16 +782,16 @@ We can see that there are _two_ differences between these outputs:
788782
Essentially, the value is the one which the user 'expects' to see based on the model definition.
789783
The 'internal' value is one that is the most convenient representation to work with inside DynamicPPL.
790784

791-
It also means that internally, the transformation in `link!!` is carried out in three steps:
785+
It also means that internally, the transformation in `link` is carried out in three steps:
792786

793787
1. Un-vectorise the internal value.
794788
2. Apply the transformation.
795789
3. Vectorise the transformed value.
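
The three steps above can be sketched in plain Julia. This is only an illustration under simplifying assumptions: the variable is a positive scalar stored as a one-element vector, and `unvectorise`, `transform`, `vectorise` and `toy_link` are hypothetical names made up for this sketch, not DynamicPPL functions.

```julia
# Hypothetical stand-ins for the three steps. These are NOT DynamicPPL functions.
unvectorise(v::AbstractVector) = only(v)  # step 1: [x] -> x
transform(x::Real) = log(x)               # step 2: map (0, Inf) onto the real line
vectorise(y::Real) = [y]                  # step 3: y -> [y]

# Apply the steps in order: un-vectorise, then transform, then vectorise
toy_link(v) = vectorise(transform(unvectorise(v)))

x_internal = [2.0]                 # assumed internal (vectorised) value of x
y_internal = toy_link(x_internal)  # one-element vector holding log(2.0)
```

The real implementation looks up the appropriate functions from the distribution, so that the same three-step structure also works for multivariate and matrix-valued variables.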
796790

797-
The actual implementation is slightly harder to parse as it has to work for different flavours of `VarInfo`, but it eventually boils down to the following:
791+
The actual implementation is slightly harder to parse as it has to work for different flavours of `VarInfo`, but it eventually boils down to the following (see the implementation [here](https://github.com/TuringLang/DynamicPPL.jl/blob/ba490bf362653e1aaefe298364fe3379b60660d3/src/varinfo.jl#L1390-L1414)):
798792

799793
```{julia}
800-
invlink!!(vi, model) # Reset to un-transformed state
794+
# Use the un-linked varinfo
801795
dist = DynamicPPL.getdist(vi, vn_x)
802796
x_val = DynamicPPL.getindex_internal(vi, vn_x)
803797
```
@@ -821,22 +815,62 @@ fn3 = DynamicPPL.to_vec_transform(dist)
821815
fn3(fn2(fn1(x_val)))
822816
```
823817

824-
### So when does the transformation actually happen?
818+
## Sampling in Turing.jl
819+
820+
DynamicPPL provides the _functionality_ for transforming variables, but the transformation itself happens at an even higher level, i.e. in the sampler itself.
821+
For example, consider the HMC sampler in Turing.jl, which is in [this file](https://github.com/TuringLang/Turing.jl/blob/5b24cebe773922e0f3d5c4cb7f53162eb758b04d/src/mcmc/hmc.jl).
822+
In the first step of sampling, it calls `link` on the sampler.
823+
This transformation is preserved throughout the sampling process, meaning that `istrans()` always returns true.
825824

826-
TODO
825+
We can observe this by inserting print statements into the model.
826+
Here, `__varinfo__` is the internal symbol for the `VarInfo` object used in model evaluation:
827827

828-
... see HMC implementation in Turing
828+
```{julia}
829+
@model function demo2()
830+
x ~ LogNormal()
831+
if x isa Float64
832+
println("-----------")
833+
println("value: $x")
834+
println("internal value: $(DynamicPPL.getindex_internal(__varinfo__, @varname(x)))")
835+
println("istrans: $(istrans(__varinfo__, @varname(x)))")
836+
end
837+
end
829838
830-
... logdensity evaluation on a LogDensityFunction -> `evaluate!!`
839+
sample(demo2(), HMC(0.1, 3), 3);
840+
```
831841

832-
... tilde pipeline
842+
(Here, the check on `if x isa Float64` prevents the printing from occurring during computation of the derivative.)
843+
You can see that during the actual sampling, `istrans` is always kept as `true`.
833844

834-
e.g. `assume` for HMC is here https://github.com/TuringLang/Turing.jl/blob/5b24cebe773922e0f3d5c4cb7f53162eb758b04d/src/mcmc/hmc.jl#L493C1-L498C4
845+
::: {.callout-note}
846+
The first two model evaluations where `istrans` is `false` occur prior to the actual sampling.
847+
One occurs when the model is checked for correctness (using [`DynamicPPL.check_model_and_trace`](https://github.com/TuringLang/DynamicPPL.jl/blob/ba490bf362653e1aaefe298364fe3379b60660d3/src/debug_utils.jl#L582-L612)).
848+
The second occurs because the model is evaluated once to generate a set of initial parameters inside [DynamicPPL's implementation of `AbstractMCMC.step`](https://github.com/TuringLang/DynamicPPL.jl/blob/ba490bf362653e1aaefe298364fe3379b60660d3/src/sampler.jl#L98-L117).
849+
Both of these steps occur with all samplers in Turing.jl.
850+
:::
835851

836-
which goes to the default `assume` implementation https://github.com/TuringLang/DynamicPPL.jl/blob/ba490bf362653e1aaefe298364fe3379b60660d3/src/context_implementations.jl#L225-L229
852+
What this means is that from the perspective of the HMC sampler, it _never_ sees the constrained variable: it always thinks that it is sampling from an unconstrained distribution.
837853

838-
which leads to `invlink_with_logpdf` https://github.com/TuringLang/DynamicPPL.jl/blob/ba490bf362653e1aaefe298364fe3379b60660d3/src/abstract_varinfo.jl#L773-L792
854+
The biggest prerequisite for this to work correctly is that the potential energy term in the Hamiltonian (in other words, the negative of the model log density) must be programmed correctly to include the Jacobian term.
855+
This is exactly the same as how we had to make sure to define `logq(y)` correctly in the toy HMC example above.
839856

840-
which returns the UNTRANSFORMED value (important to explain why here – it's because the return value is assigned to the variable in the model, which the user can see) and the appropriately calculated logpdf, depending on whether `istrans(vi, vn)` returns true
857+
This occurs correctly because a statement like `x ~ LogNormal()` in the model definition above is translated into `assume(LogNormal(), @varname(x), __varinfo__)`, defined [here](https://github.com/TuringLang/DynamicPPL.jl/blob/ba490bf362653e1aaefe298364fe3379b60660d3/src/context_implementations.jl#L225-L229).
858+
As can be seen by following through on the definition of `invlink_with_logpdf`, this does indeed check for the presence of the `istrans` flag and adds the Jacobian term accordingly.
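
The branching on the flag can be sketched as follows. This is a rough illustration of the idea for `x ~ LogNormal()` only, not the DynamicPPL implementation: `toy_assume` and `lognormal_logpdf` are hypothetical helpers written for this sketch, and the log density is written out by hand so that the block has no dependencies.

```julia
# Log density of LogNormal(0, 1), written out by hand to avoid dependencies
lognormal_logpdf(x) = -log(x) - 0.5 * log(2π) - 0.5 * log(x)^2

# Hypothetical helper, NOT the DynamicPPL implementation.
# Returns the value in the original (constrained) space together with the
# log density contribution, which depends on how the value is stored.
function toy_assume(internal::Real, is_transformed::Bool)
    if is_transformed
        x = exp(internal)                      # undo the log transform
        logp = lognormal_logpdf(x) + internal  # add the Jacobian term log|dx/dy| = y
    else
        x = internal                           # stored value is already constrained
        logp = lognormal_logpdf(x)             # no Jacobian term needed
    end
    return x, logp
end

x1, logp1 = toy_assume(2.0, false)      # un-linked: internal value is x itself
x2, logp2 = toy_assume(log(2.0), true)  # linked: internal value is log(x)
```

In both cases the returned value is the untransformed one, since that is what gets assigned to `x` inside the model; only the log density differs, by exactly the Jacobian term.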
841859

842-
TODO: Understand `maybe_invlink_before_eval`.
860+
::: {.callout-note}
861+
The discussion above skips over several steps in the Turing.jl codebase, which can be difficult to follow.
862+
Specifically:
863+
864+
1. Samplers such as HMC [wrap Turing models in a `DynamicPPL.LogDensityFunction`](https://github.com/TuringLang/Turing.jl/blob/5b24cebe773922e0f3d5c4cb7f53162eb758b04d/src/mcmc/hmc.jl#L159-L168).
865+
2. The log density at a given set of parameter values can then be [calculated using `logdensity`](https://github.com/TuringLang/DynamicPPL.jl/blob/ba490bf362653e1aaefe298364fe3379b60660d3/src/logdensityfunction.jl#L136-L141).
866+
3. This in turn calls `evaluate!!`, which runs the _model evaluator function_. This evaluator function is not visible in the DynamicPPL codebase because it is generated by the expansion of the `@model` macro. You can see it, though, by running:
867+
```julia
868+
@macroexpand @model demo3() = x ~ LogNormal()
869+
```
870+
Note that these evaluations do not trigger the print statements in the model because they are run using automatic differentiation (in this case, `x` is a `ForwardDiff.Dual`).
871+
4. This generates a line which looks like
872+
```julia
873+
(var"##value#441", __varinfo__) = (DynamicPPL.tilde_assume!!)(__context__, (DynamicPPL.unwrap_right_vn)((DynamicPPL.check_tilde_rhs)(var"##dist#440"), var"##vn#437")..., __varinfo__)
874+
```
875+
`tilde_assume!!` in turn calls `tilde_assume`, which ultimately delegates to `assume`.
876+
:::
