In DynamicPPL, the `link!!` function can be used to transform the variables.
734
-
These functions do three things: firstly, they transform the variables; secondly, they update the value of logp (by adding the Jacobian term); and thirdly, they set a flag on the variables to indicate that it has been transformed.
735
-
Note that these functions act on _all_ variables in the model, including unconstrained ones.
730
+
In DynamicPPL, the `link` function can be used to transform the variables.
731
+
This function does three things: firstly, it transforms the variables; secondly, it updates the value of logp (by adding the Jacobian term); and thirdly, it sets a flag on the variables to indicate that they have been transformed.
732
+
Note that this acts on _all_ variables in the model, including unconstrained ones.
736
733
(Unconstrained variables just have an identity transformation.)
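The three effects of linking can be illustrated with a minimal, dependency-free sketch. Everything here is hypothetical: `ToyVarInfo`, `toy_link`, and `lognormal_logpdf` are stand-ins written for this illustration and are not part of DynamicPPL. The sketch assumes a single variable `x ~ LogNormal(0, 1)` and the log function as the bijector.

```julia
# Toy stand-in for a varinfo (NOT DynamicPPL's VarInfo): one variable only.
struct ToyVarInfo
    value::Float64   # internal value of the variable
    logp::Float64    # accumulated log density
    istrans::Bool    # has the internal value been transformed?
end

# Log density of LogNormal(0, 1) at x > 0, written out by hand.
lognormal_logpdf(x) = -log(x) - 0.5 * log(2π) - 0.5 * log(x)^2

# Hypothetical `toy_link`, mirroring the three effects described above.
function toy_link(vi::ToyVarInfo)
    vi.istrans && return vi          # already linked: nothing to do
    y = log(vi.value)                # 1. transform the variable
    logjac = y                       # log|dx/dy| = log(exp(y)) = y
    # 2. add the Jacobian term to logp; 3. set the flag
    return ToyVarInfo(y, vi.logp + logjac, true)
end

x = 2.0
vi = ToyVarInfo(x, lognormal_logpdf(x), false)
linked = toy_link(vi)
```

With these assumptions, `linked.logp` agrees with the log density of a standard normal evaluated at `log(2.0)`, which is what one would expect for the logarithm of a `LogNormal(0, 1)` variable.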
We can see that there are _two_ differences between these outputs:
768
+
We can see (for the linked varinfo) that there are _two_ differences between these outputs:
775
769
776
770
1. _The internal value has been transformed using the bijector (in this case, the log function)._
777
771
This means that the `istrans()` flag which we used above doesn't tell us anything about whether the 'external' value has been transformed: it only tells us about the internal value.
@@ -788,16 +782,16 @@ We can see that there are _two_ differences between these outputs:
788
782
Essentially, the value is the one which the user 'expects' to see based on the model definition.
789
783
The 'internal' value is one that is the most convenient representation to work with inside DynamicPPL.
790
784
791
-
It also means that internally, the transformation in `link!!` is carried out in three steps:
785
+
It also means that internally, the transformation in `link` is carried out in three steps:
792
786
793
787
1. Un-vectorise the internal value.
794
788
2. Apply the transformation.
795
789
3. Vectorise the transformed value.
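The three steps above can be sketched in plain Julia. The helper names (`unvectorise`, `vectorise`, `toy_link_internal`) are invented for this illustration and do not correspond to DynamicPPL functions. The sketch also assumes that the transformation acts elementwise, which holds for the log function but not for every bijector.

```julia
# Hypothetical helpers; the names are illustrative, not DynamicPPL's API.
unvectorise(v::Vector{Float64}, dims) = reshape(v, dims)  # step 1
vectorise(x::AbstractArray) = vec(collect(x))             # step 3

function toy_link_internal(internal::Vector{Float64}, dims, transform)
    x = unvectorise(internal, dims)  # 1. recover the original shape
    y = transform.(x)                # 2. apply the transformation (elementwise here)
    return vectorise(y)              # 3. flatten back into a vector
end

# A 2x2 positive-valued variable stored internally as a flat vector.
internal = [1.0, 2.0, 3.0, 4.0]
linked = toy_link_internal(internal, (2, 2), log)
```

The result is again a flat `Vector{Float64}`, so the internal storage keeps the same layout before and after linking.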
796
790
797
-
The actual implementation is slightly harder to parse as it has to work for different flavours of `VarInfo`, but it eventually boils down to the following:
791
+
The actual implementation is slightly harder to parse as it has to work for different flavours of `VarInfo`, but it eventually boils down to the following (see the implementation [here](https://github.com/TuringLang/DynamicPPL.jl/blob/ba490bf362653e1aaefe298364fe3379b60660d3/src/varinfo.jl#L1390-L1414)):
798
792
799
793
```{julia}
800
-
invlink!!(vi, model) # Reset to un-transformed state
### So when does the transformation actually happen?
818
+
## Sampling in Turing.jl
819
+
820
+
DynamicPPL provides the _functionality_ for transforming variables, but the transformation itself happens at an even higher level, i.e. in the sampler itself.
821
+
For example, consider the HMC sampler in Turing.jl, which is in [this file](https://github.com/TuringLang/Turing.jl/blob/5b24cebe773922e0f3d5c4cb7f53162eb758b04d/src/mcmc/hmc.jl).
822
+
In the first step of sampling, it calls `link` on the sampler.
823
+
This transformation is preserved throughout the sampling process, meaning that `istrans()` always returns true.
825
824
826
-
TODO
825
+
We can observe this by inserting print statements into the model.
826
+
Here, `__varinfo__` is the internal symbol for the `VarInfo` object used in model evaluation:
... logdensity evaluation on a LogDensityFunction -> `evaluate!!`
839
+
sample(demo2(), HMC(0.1, 3), 3);
840
+
```
831
841
832
-
... tilde pipeline
842
+
(Here, the check on `if x isa Float64` prevents the printing from occurring during computation of the derivative.)
843
+
You can see that during the actual sampling, `istrans` is always kept as `true`.
833
844
834
-
e.g. `assume` for HMC is here https://github.com/TuringLang/Turing.jl/blob/5b24cebe773922e0f3d5c4cb7f53162eb758b04d/src/mcmc/hmc.jl#L493C1-L498C4
845
+
::: {.callout-note}
846
+
The first two model evaluations where `istrans` is `false` occur prior to the actual sampling.
847
+
One occurs when the model is checked for correctness (using [`DynamicPPL.check_model_and_trace`](https://github.com/TuringLang/DynamicPPL.jl/blob/ba490bf362653e1aaefe298364fe3379b60660d3/src/debug_utils.jl#L582-L612)).
848
+
The second occurs because the model is evaluated once to generate a set of initial parameters inside [DynamicPPL's implementation of `AbstractMCMC.step`](https://github.com/TuringLang/DynamicPPL.jl/blob/ba490bf362653e1aaefe298364fe3379b60660d3/src/sampler.jl#L98-L117).
849
+
Both of these steps occur with all samplers in Turing.jl.
850
+
:::
835
851
836
-
which goes to the default `assume` implementation https://github.com/TuringLang/DynamicPPL.jl/blob/ba490bf362653e1aaefe298364fe3379b60660d3/src/context_implementations.jl#L225-L229
852
+
What this means is that from the perspective of the HMC sampler, it _never_ sees the constrained variable: it always thinks that it is sampling from an unconstrained distribution.
837
853
838
-
which leads to `invlink_with_logpdf`https://github.com/TuringLang/DynamicPPL.jl/blob/ba490bf362653e1aaefe298364fe3379b60660d3/src/abstract_varinfo.jl#L773-L792
854
+
The biggest prerequisite for this to work correctly is that the potential energy term in the Hamiltonian—or in other words, the model log density—must be programmed correctly to include the Jacobian term.
855
+
This is exactly the same as how we had to make sure to define `logq(y)` correctly in the toy HMC example above.
839
856
840
-
which returns the UNTRANSFORMED value (important to explain why here – it's because the return value is assigned to the variable in the model, which the user can see) and the appropriately calculated logpdf, depending on whether `istrans(vi, vn)` returns true
857
+
This occurs correctly because a statement like `x ~ LogNormal()` in the model definition above is translated into `assume(LogNormal(), @varname(x), __varinfo__)`, defined [here](https://github.com/TuringLang/DynamicPPL.jl/blob/ba490bf362653e1aaefe298364fe3379b60660d3/src/context_implementations.jl#L225-L229).
858
+
As can be seen by following through on the definition of `invlink_with_logpdf`, this does indeed check for the presence of the `istrans` flag and adds the Jacobian accordingly.
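The behaviour described here can be mimicked with a small stand-alone sketch. It is not the DynamicPPL implementation: `toy_invlink_with_logpdf` and `lognormal_logpdf` are hypothetical names, and the sketch hard-codes `LogNormal(0, 1)` with the log function as the bijector. It returns the untransformed value together with a log density that includes the Jacobian term only when the flag is set.

```julia
# Log density of LogNormal(0, 1) at x > 0, written out by hand.
lognormal_logpdf(x) = -log(x) - 0.5 * log(2π) - 0.5 * log(x)^2

# Hypothetical analogue of the "value plus logpdf" step (illustration only).
function toy_invlink_with_logpdf(internal::Float64, istrans::Bool)
    if istrans
        x = exp(internal)  # undo the log transform to get the model-space value
        # log density in unconstrained space: add log|dx/dy| = internal
        return x, lognormal_logpdf(x) + internal
    else
        # the internal value is already in model space: no Jacobian term
        return internal, lognormal_logpdf(internal)
    end
end

x_linked, logp_linked = toy_invlink_with_logpdf(log(2.0), true)
x_plain, logp_plain = toy_invlink_with_logpdf(2.0, false)
```

Both calls return the same model-space value for `x`, and the two log densities differ by exactly the log-Jacobian term.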
841
859
842
-
TODO: Understand `maybe_invlink_before_eval`.
860
+
::: {.callout-note}
861
+
The discussion above skips over several steps in the Turing.jl codebase, which can be difficult to follow.
862
+
Specifically:
863
+
864
+
1. Samplers such as HMC [wrap Turing models in a `DynamicPPL.LogDensityFunction`](https://github.com/TuringLang/Turing.jl/blob/5b24cebe773922e0f3d5c4cb7f53162eb758b04d/src/mcmc/hmc.jl#L159-L168).
865
+
2. The log density at a given set of parameter values can then be [calculated using `logdensity`](https://github.com/TuringLang/DynamicPPL.jl/blob/ba490bf362653e1aaefe298364fe3379b60660d3/src/logdensityfunction.jl#L136-L141).
866
+
3. This in turn calls `evaluate!!`, which runs the _model evaluator function_. This evaluator function is not visible in the DynamicPPL codebase because it is generated by the expansion of the `@model` macro. You can see it, though, by running:
867
+
```julia
868
+
@macroexpand @model demo3() = x ~ LogNormal()
869
+
```
870
+
Note that these evaluations do not trigger the print statements in the model because they are run using automatic differentiation (in this case, `x` is a `ForwardDiff.Dual`).