This repository provides JAX implementations of several off-policy reinforcement learning (RL) algorithms, with a focus on using JAX to make training efficient.
Key features:
- Efficient JAX Implementation: Optimized for speed and performance.
- Clean and Simple Code: Designed for clarity and ease of understanding.
- Comparison with PyTorch: Includes benchmarks comparing training speed against PyTorch implementations.
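JAX training code is usually fast because a whole update step is compiled at once. The sketch below is not code from this repository. It uses `jax.jit` to compile one critic update (forward pass, gradient, and SGD step) into a single XLA program. Several parts are simplifying assumptions: the linear critic `q_value`, the `batch` dictionary keys, the plain SGD step, and taking the next action from the batch rather than sampling it from a target actor.

```python
import jax
import jax.numpy as jnp


def q_value(params, obs, act):
    # Hypothetical linear critic: Q(s, a) = [s, a] @ w + b.
    x = jnp.concatenate([obs, act], axis=-1)
    return x @ params["w"] + params["b"]


def td_loss(params, target_params, batch, gamma):
    q = q_value(params, batch["obs"], batch["act"])
    # Bootstrapped target from a frozen target critic; no gradient flows through it.
    next_q = q_value(target_params, batch["next_obs"], batch["next_act"])
    target = batch["rew"] + gamma * (1.0 - batch["done"]) * next_q
    return jnp.mean((q - jax.lax.stop_gradient(target)) ** 2)


@jax.jit
def update(params, target_params, batch, lr, gamma):
    # Forward, backward and the SGD step are traced once and compiled by XLA;
    # later calls with the same array shapes reuse the compiled program.
    loss, grads = jax.value_and_grad(td_loss)(params, target_params, batch, gamma)
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return new_params, loss


def make_batch(key, n=32, obs_dim=3, act_dim=2):
    # Random transitions, only to demonstrate the call signature.
    k_obs, k_act, k_nobs, k_nact = jax.random.split(key, 4)
    return {
        "obs": jax.random.normal(k_obs, (n, obs_dim)),
        "act": jax.random.normal(k_act, (n, act_dim)),
        "next_obs": jax.random.normal(k_nobs, (n, obs_dim)),
        "next_act": jax.random.normal(k_nact, (n, act_dim)),
        "rew": jnp.ones((n,)),
        "done": jnp.zeros((n,)),
    }


key_w, key_batch = jax.random.split(jax.random.PRNGKey(0))
params = {"w": 0.1 * jax.random.normal(key_w, (5,)), "b": jnp.zeros(())}
target_params = params
batch = make_batch(key_batch)

# The returned loss is evaluated at the parameters passed in (before the step).
_, first_loss = update(params, target_params, batch, 0.05, 0.99)
for _ in range(200):
    params, loss = update(params, target_params, batch, 0.05, 0.99)
```

`update` is a pure function of its inputs, so JAX traces it only once. Every later call with the same array shapes reuses the compiled program. A different batch size would trigger a recompilation.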
Implemented algorithms:
- TD7
- SALE-TQC: SALE representation (from TD7) + TQC
- SIMBA
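SALE-TQC combines two published components. The sketch below shows one plausible JAX form of each. It is not taken from this repository, and the function names, argument layout, and SAC-style entropy term are assumptions.

- `avg_l1_norm` is the AvgL1Norm used by TD7's SALE encoders.
- `tqc_target` and `quantile_huber_loss` implement TQC's truncated distributional target. The quantiles from all target critics are pooled and sorted, and the largest `drop_per_critic * n_critics` atoms are discarded. Each critic is then trained toward the remaining atoms with a quantile Huber loss.

```python
import jax
import jax.numpy as jnp


def avg_l1_norm(x, eps=1e-8):
    # SALE's AvgL1Norm: scale each row so its mean absolute value is 1.
    return x / jnp.maximum(jnp.mean(jnp.abs(x), axis=-1, keepdims=True), eps)


def tqc_target(next_quantiles, reward, done, next_log_prob, alpha, gamma, drop_per_critic):
    # next_quantiles: (batch, n_critics, n_quantiles) from the target critics at (s', a').
    batch, n_critics, n_quantiles = next_quantiles.shape
    # Pool the atoms of all critics and sort them in ascending order.
    pooled = jnp.sort(next_quantiles.reshape(batch, n_critics * n_quantiles), axis=-1)
    # Drop the largest drop_per_critic * n_critics atoms to curb overestimation.
    n_keep = n_critics * (n_quantiles - drop_per_critic)
    truncated = pooled[:, :n_keep]
    # SAC-style soft target (assumption), broadcast over the kept atoms.
    soft = truncated - alpha * next_log_prob[:, None]
    return reward[:, None] + gamma * (1.0 - done[:, None]) * soft


def quantile_huber_loss(pred_quantiles, target_atoms):
    # pred_quantiles: (batch, n_critics, n_quantiles); target_atoms: (batch, n_targets).
    n_quantiles = pred_quantiles.shape[-1]
    tau = (jnp.arange(n_quantiles) + 0.5) / n_quantiles  # quantile midpoints
    target_atoms = jax.lax.stop_gradient(target_atoms)
    # Pairwise errors, shape (batch, n_critics, n_quantiles, n_targets).
    u = target_atoms[:, None, None, :] - pred_quantiles[..., None]
    abs_u = jnp.abs(u)
    huber = jnp.where(abs_u <= 1.0, 0.5 * u**2, abs_u - 0.5)  # Huber loss with kappa = 1
    # Asymmetric quantile weight |tau - 1{u < 0}|.
    weight = jnp.abs(tau[None, None, :, None] - (u < 0).astype(u.dtype))
    return jnp.mean(weight * huber)
```

Truncation works on the pooled atoms, not on each critic separately. So the number of atoms dropped scales with the ensemble size, while the number kept per critic stays `n_quantiles - drop_per_critic`.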
Figure: learning curves of each algorithm (x-axis: timestep; y-axis: performance).
Figure: training-speed comparison between the JAX and PyTorch implementations.
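A fair training-speed comparison in JAX depends on two details. First, the first call of a jitted function includes tracing and compilation. Second, JAX dispatches work asynchronously, so the timer must wait for results. The sketch below is a minimal timing helper that handles both. `benchmark` and `mlp_forward` are hypothetical names, and this is not necessarily how the repository measures its own benchmarks.

```python
import time

import jax
import jax.numpy as jnp


def benchmark(fn, *args, n_iters=100):
    # First call: tracing + XLA compilation + one execution.
    start = time.perf_counter()
    jax.block_until_ready(fn(*args))
    first_call = time.perf_counter() - start

    # Steady state: JAX dispatches work asynchronously, so wait on the last
    # result before stopping the clock; otherwise only dispatch time is measured.
    start = time.perf_counter()
    out = None
    for _ in range(n_iters):
        out = fn(*args)
    jax.block_until_ready(out)
    per_call = (time.perf_counter() - start) / n_iters
    return first_call, per_call


@jax.jit
def mlp_forward(w1, w2, x):
    # Hypothetical two-layer MLP used only as a workload to time.
    return jnp.tanh(x @ w1) @ w2


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
w1 = jax.random.normal(k1, (64, 256))
w2 = jax.random.normal(k2, (256, 8))
x = jax.random.normal(k3, (512, 64))
first_call, per_call = benchmark(mlp_forward, w1, w2, x, n_iters=50)
print(f"first call: {first_call * 1e3:.2f} ms, steady state: {per_call * 1e3:.3f} ms/call")
```

Reporting the first call separately stops one-off compilation cost from inflating the per-step time. On the PyTorch side, GPU timings need the analogous `torch.cuda.synchronize()` before reading the clock.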
```bash
# Clone the repository
git clone https://github.com/seungju-k1m/jax-offpolicy-rl.git
cd jax-offpolicy-rl

# Install dependencies
rye sync

# Train TD7 on Humanoid-v4 (seed 1), saving results under save/TD7
rye run python cli.py td7 --env-id Humanoid-v4 --save-path "save/TD7" --seed 1 --use-progressbar
```
The repository includes scripts to visualize learning curves and compare training efficiency.