Performance regression due to jaxlib v0.4.32 XLA backend update #242
Labels: comp:goose (this issue is related to the goose module), performance (something is working but the code is slow)
With the release of jaxlib v0.4.32, JAX has transitioned to a new XLA backend for CPUs (see the changelog). This update may introduce a performance regression, particularly affecting the NUTS and HMC kernels.
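To check whether an environment is on the release named above, one can compare the installed jaxlib version against 0.4.32. The sketch below is a minimal illustration; the helper names `parse_release`, `is_at_least`, and `installed_jaxlib_version` are hypothetical and not part of JAX. It only compares version numbers and does not detect whether the slowdown actually occurs.

```python
import re
from importlib import metadata
from typing import Optional, Tuple

# Release that, per this issue, switched JAX to the new XLA CPU backend.
NEW_CPU_BACKEND_SINCE = (0, 4, 32)


def parse_release(version: str) -> Tuple[int, int, int]:
    """Extract (major, minor, patch) from strings like '0.4.32' or '0.4.33.dev20240801'."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    major, minor, patch = (int(g) for g in match.groups())
    return major, minor, patch


def is_at_least(version: str, minimum: Tuple[int, int, int] = NEW_CPU_BACKEND_SINCE) -> bool:
    """True if `version` is the given release or newer (plain tuple comparison)."""
    return parse_release(version) >= minimum


def installed_jaxlib_version() -> Optional[str]:
    """Return the installed jaxlib version, or None if jaxlib is not installed."""
    try:
        return metadata.version("jaxlib")
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    version = installed_jaxlib_version()
    if version is None:
        print("jaxlib is not installed")
    else:
        print(f"jaxlib {version}: at or after 0.4.32 -> {is_at_least(version)}")
```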
As a potential temporary workaround, setting the XLA flag --xla_cpu_use_thunk_runtime=false can help mitigate the slowdown. You can apply this fix by running your code with the following environment variable set:
XLA_FLAGS=--xla_cpu_use_thunk_runtime=false
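If editing the shell environment is inconvenient, the same flag can be set from Python. The sketch below assumes that XLA reads XLA_FLAGS once, when JAX first initializes its backend, so it has to run before `import jax`. The helper names `add_xla_flag` and `disable_thunk_runtime` are hypothetical. The helper keeps any flags already present in XLA_FLAGS instead of overwriting them.

```python
import os
from typing import Optional

# Flag from this issue: ask XLA's CPU backend not to use the thunk runtime.
THUNK_FLAG = "--xla_cpu_use_thunk_runtime=false"


def add_xla_flag(existing: Optional[str], flag: str = THUNK_FLAG) -> str:
    """Return an XLA_FLAGS string containing `flag` exactly once.

    Other flags already present (e.g. set in the user's shell) are kept.
    """
    flags = (existing or "").split()
    # Drop any earlier setting of the same flag name so ours is the only one.
    name = flag.split("=", 1)[0]
    flags = [f for f in flags if f.split("=", 1)[0] != name]
    flags.append(flag)
    return " ".join(flags)


def disable_thunk_runtime() -> None:
    """Update os.environ in place; call this before `import jax`."""
    os.environ["XLA_FLAGS"] = add_xla_flag(os.environ.get("XLA_FLAGS"))


if __name__ == "__main__":
    disable_thunk_runtime()
    print(os.environ["XLA_FLAGS"])
```

Call `disable_thunk_runtime()` at the very top of the entry script and only then import jax and the sampling code. Preserving existing flags matters because XLA_FLAGS is a single space-separated string, so a plain assignment would silently drop, for example, a device-count flag set elsewhere.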
Related issues and discussions:
blackjax
Discussions in the jax repo