Skip to content

Commit

Permalink
Merge branch 'main' into plot-var
Browse files Browse the repository at this point in the history
  • Loading branch information
jobrachem committed Jan 29, 2025
2 parents 17587c8 + cca4c56 commit a09ae28
Show file tree
Hide file tree
Showing 31 changed files with 1,085 additions and 414 deletions.
63 changes: 17 additions & 46 deletions .github/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -57,16 +57,16 @@ import jax.numpy as jnp
import tensorflow_probability.substrates.jax.distributions as tfd
import liesel.model as lsl

loc = lsl.param(0.0, name="loc")
scale = lsl.param(1.0, name="scale")
loc = lsl.Var.new_param(0.0, name="loc")
scale = lsl.Var.new_param(1.0, name="scale")

y = lsl.obs(
y = lsl.Var.new_obs(
value=jnp.array([1.314, 0.861, -1.813, 0.587, -1.408]),
distribution=lsl.Dist(tfd.Normal, loc=loc, scale=scale),
name="y",
)

model = lsl.Model([loc, scale, y])
model = lsl.Model([y])
```

The model allows us to evaluate the log-probability through a property,
Expand Down Expand Up @@ -101,44 +101,15 @@ builder.add_kernel(gs.NUTSKernel(["loc"]))
builder.set_model(gs.LieselInterface(model))
builder.set_initial_values(model.state)

# we disable the progress bar for a nicer display here in the readme
builder.show_progress = False

builder.set_duration(warmup_duration=1000, posterior_duration=1000)

engine = builder.build()
```

liesel.goose.builder - WARNING - No jitter functions provided. The initial values won't be jittered
liesel.goose.engine - INFO - Initializing kernels...
liesel.goose.engine - INFO - Done

``` python
engine.sample_all_epochs()
```

liesel.goose.engine - INFO - Starting epoch: FAST_ADAPTATION, 75 transitions, 25 jitted together
liesel.goose.engine - WARNING - Errors per chain for kernel_00: 2, 1, 2, 0 / 75 transitions
liesel.goose.engine - INFO - Finished epoch
liesel.goose.engine - INFO - Starting epoch: SLOW_ADAPTATION, 25 transitions, 25 jitted together
liesel.goose.engine - WARNING - Errors per chain for kernel_00: 1, 1, 1, 1 / 25 transitions
liesel.goose.engine - INFO - Finished epoch
liesel.goose.engine - INFO - Starting epoch: SLOW_ADAPTATION, 50 transitions, 25 jitted together
liesel.goose.engine - WARNING - Errors per chain for kernel_00: 1, 1, 1, 1 / 50 transitions
liesel.goose.engine - INFO - Finished epoch
liesel.goose.engine - INFO - Starting epoch: SLOW_ADAPTATION, 100 transitions, 25 jitted together
liesel.goose.engine - WARNING - Errors per chain for kernel_00: 1, 2, 2, 1 / 100 transitions
liesel.goose.engine - INFO - Finished epoch
liesel.goose.engine - INFO - Starting epoch: SLOW_ADAPTATION, 200 transitions, 25 jitted together
liesel.goose.engine - WARNING - Errors per chain for kernel_00: 1, 4, 1, 1 / 200 transitions
liesel.goose.engine - INFO - Finished epoch
liesel.goose.engine - INFO - Starting epoch: SLOW_ADAPTATION, 500 transitions, 25 jitted together
liesel.goose.engine - WARNING - Errors per chain for kernel_00: 2, 1, 1, 2 / 500 transitions
liesel.goose.engine - INFO - Finished epoch
liesel.goose.engine - INFO - Starting epoch: FAST_ADAPTATION, 50 transitions, 25 jitted together
liesel.goose.engine - WARNING - Errors per chain for kernel_00: 1, 1, 2, 2 / 50 transitions
liesel.goose.engine - INFO - Finished epoch
liesel.goose.engine - INFO - Finished warmup
liesel.goose.engine - INFO - Starting epoch: POSTERIOR, 1000 transitions, 25 jitted together
liesel.goose.engine - INFO - Finished epoch

Finally, we can print a summary table and view some diagnostic plots.

``` python
Expand Down Expand Up @@ -228,31 +199,31 @@ loc
kernel_00
</td>
<td>
-0.083
-0.095
</td>
<td>
0.445
0.452
</td>
<td>
-0.810
-0.823
</td>
<td>
-0.091
-0.100
</td>
<td>
0.652
0.664
</td>
<td>
4000
</td>
<td>
1459.234
1488.973
</td>
<td>
1874.643
2085.720
</td>
<td>
1.002
1.001
</td>
</tr>
</tbody>
Expand Down Expand Up @@ -312,10 +283,10 @@ divergent transition
warmup
</th>
<td>
38
47
</td>
<td>
0.009
0.012
</td>
</tr>
<tr>
Expand Down
11 changes: 7 additions & 4 deletions .github/README.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,16 @@ import jax.numpy as jnp
import tensorflow_probability.substrates.jax.distributions as tfd
import liesel.model as lsl
loc = lsl.param(0.0, name="loc")
scale = lsl.param(1.0, name="scale")
loc = lsl.Var.new_param(0.0, name="loc")
scale = lsl.Var.new_param(1.0, name="scale")
y = lsl.obs(
y = lsl.Var.new_obs(
value=jnp.array([1.314, 0.861, -1.813, 0.587, -1.408]),
distribution=lsl.Dist(tfd.Normal, loc=loc, scale=scale),
name="y",
)
model = lsl.Model([loc, scale, y])
model = lsl.Model([y])
```

The model allows us to evaluate the log-probability through a property, which is updated automatically if the value of a node is modified.
Expand All @@ -101,6 +101,9 @@ builder.add_kernel(gs.NUTSKernel(["loc"]))
builder.set_model(gs.LieselInterface(model))
builder.set_initial_values(model.state)
# we disable the progress bar for a nicer display here in the readme
builder.show_progress = False
builder.set_duration(warmup_duration=1000, posterior_duration=1000)
engine = builder.build()
Expand Down
Loading

0 comments on commit a09ae28

Please sign in to comment.