
Performance regression due to jaxlib v0.4.32 XLA backend update #242

Open · wiep opened this issue Feb 18, 2025 · 0 comments
Labels: comp:goose (This issue is related to the goose module), performance (Something is working but the code is slow)

wiep (Contributor) commented Feb 18, 2025

With the release of jaxlib v0.4.32, JAX has switched to a new XLA backend for CPUs (see the changelog). This update may cause a performance regression, particularly in the NUTS and HMC kernels.

As a temporary workaround, setting the XLA flag --xla_cpu_use_thunk_runtime=false may mitigate the slowdown. To apply it, set the following environment variable when running your code:

XLA_FLAGS=--xla_cpu_use_thunk_runtime=false
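If you cannot set the variable in the shell (for example, inside a notebook), the same flag can be set from Python. The sketch below assumes that XLA reads `XLA_FLAGS` when JAX initializes its CPU backend, so it sets the variable before `jax` is imported. It also keeps any XLA flags you have already set. The helper `add_xla_flag` is hypothetical and not part of JAX or goose.

```python
import os

# Flag from this issue: disables the new XLA:CPU thunk runtime.
THUNK_FLAG = "--xla_cpu_use_thunk_runtime=false"


def add_xla_flag(existing: str, flag: str) -> str:
    """Return `existing` (a space-separated XLA_FLAGS string) with `flag` set.

    Any earlier setting of the same flag name is dropped first, so the
    result contains exactly one occurrence of that flag.
    """
    name = flag.split("=", 1)[0]
    kept = [p for p in existing.split() if p.split("=", 1)[0] != name]
    kept.append(flag)
    return " ".join(kept)


# Assumption: this must run before `import jax`, because XLA reads
# XLA_FLAGS when the backend starts up.
os.environ["XLA_FLAGS"] = add_xla_flag(os.environ.get("XLA_FLAGS", ""), THUNK_FLAG)

# import jax  # import jax (and goose) only after XLA_FLAGS is set
```

For example, an existing `XLA_FLAGS="--xla_force_host_platform_device_count=4"` becomes `"--xla_force_host_platform_device_count=4 --xla_cpu_use_thunk_runtime=false"`.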

Related

Related issues and discussions:

blackjax

Discussions in the jax repo
