MetaDE is an evolutionary framework that optimizes the strategies and hyperparameters of Differential Evolution (DE) through meta-level evolution. By using DE to tune its own configuration, MetaDE adapts mutation and crossover strategies to the problem landscape during the run. With GPU acceleration, it handles large-scale black-box optimization problems efficiently, with the aim of faster convergence and better final solutions than a fixed DE configuration. MetaDE is compatible with the EvoX framework.
New in this version:
MetaDE now supports both JAX and PyTorch backends.
- If you want to run PyTorch-based code, please install the GPU-enabled PyTorch (if GPU acceleration is desired) along with EvoX ≥ 1.0.0.
- Important: The PyTorch version currently does not support Brax-based RL tasks. If you need Brax support (or wish to use the version from the paper experiments), please use the JAX backend with a CUDA-enabled JAX (and Brax) installation.
- Meta-level Evolution 🌱: Uses DE at a meta-level to evolve hyperparameters and strategies of DE applied at a problem-solving level.
- Parameterized DE (PDE) 🛠️: A customizable variant of DE that offers dynamic mutation and crossover strategies adaptable to different optimization problems.
- Multi-Backend Support 🔥: Provides both JAX and PyTorch implementations for broader hardware/software compatibility.
- GPU-Accelerated 🚀: Integrated with GPU acceleration on both JAX and PyTorch, enabling efficient large-scale optimizations.
- End-to-End Optimization 🔄: MetaDE provides a seamless workflow from hyperparameter tuning to solving optimization problems in a fully automated process.
- Wide Applicability 🤖: Supports various benchmarks, including CEC2022, and real-world tasks like evolutionary reinforcement learning in robotics.
Using the MetaDE algorithm to solve RL tasks.
The following animations show the behaviors in Brax environments:
[Animations: Hopper | Swimmer | Reacher]
- Hopper: Aiming for maximum speed and jumping height.
- Swimmer: Enhancing movement efficiency in fluid environments.
- Reacher: Moving the fingertip to a random target.
Depending on which backend you plan to use (JAX or PyTorch), you should install the proper libraries and GPU dependencies:
- Common:
  - evox (version ≥ 1.0.0 for PyTorch support)
- JAX-based version:
  - jax >= 0.4.16
  - jaxlib >= 0.3.0
  - brax == 0.10.3 (optional, if you want to run Brax RL problems)
- PyTorch-based version:
  - torch (GPU version recommended, e.g. torch>=2.5.0)
  - torchvision, torchaudio (optional, depending on your environment/needs)
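To see which of the packages listed above are present in the current environment, a small check like the following can help. It is a sketch that uses only the Python standard library; the helper name `installed_versions` is made up for illustration, and the package list simply mirrors the requirements above.

```python
from importlib import metadata


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    # Package names taken from the requirements list above.
    report = installed_versions(["evox", "jax", "jaxlib", "brax", "torch"])
    for name, version in report.items():
        print(f"{name:8s} {version or 'not installed'}")
```

Comparing the printed versions against the minimums above is left to the reader, since the right pins depend on the backend you choose.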
You can install MetaDE with either the JAX or PyTorch backend (or both).
Below are some example installation steps; please adapt versions as needed:
- Install PyTorch (with CUDA, if you want GPU acceleration). For example, if you have CUDA 12.4, you might do:
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
- Install EvoX ≥ 1.0.0 (for PyTorch support):
pip install git+https://github.com/EMI-Group/evox.git@v1.0.1
- Install MetaDE:
pip install git+https://github.com/EMI-Group/metade.git
- Install JAX. We recommend jax >= 0.4.16. For the CPU-only version, you may use:
pip install -U jax
For NVIDIA GPUs, you may use:
pip install -U "jax[cuda12]"
For details on installing JAX, please check https://github.com/google/jax.
- Install EvoX ≥ 1.0.0 (for PyTorch support):
pip install git+https://github.com/EMI-Group/evox.git@v1.0.1
- Install MetaDE:
pip install git+https://github.com/EMI-Group/metade.git
- Install Brax (optional, if you want to solve Brax RL problems):
pip install brax==0.10.3
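After installation, it can be useful to confirm that the chosen backend actually sees a GPU. The sketch below is one way to do that; the function name `visible_devices` is hypothetical, and each backend is only queried if it can be imported.

```python
def visible_devices():
    """Report the device platforms each installed backend can see."""
    report = {}
    try:
        import jax

        # jax.devices() lists the devices of the default JAX backend.
        report["jax"] = sorted({d.platform for d in jax.devices()})
    except ImportError:
        report["jax"] = None  # JAX is not installed
    try:
        import torch

        report["torch"] = ["cuda"] if torch.cuda.is_available() else ["cpu"]
    except ImportError:
        report["torch"] = None  # PyTorch is not installed
    return report


print(visible_devices())
```

A value of `None` means the backend is not installed; a list containing only `cpu` suggests that a CPU-only build was installed or that no GPU is visible.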
MetaDE employs Differential Evolution (DE) as the evolver to optimize the parameters of its executor.
- Mutation: DE's mutation strategies evolve based on feedback from the problem landscape.
- Crossover: Different crossover strategies (binomial, exponential, arithmetic) can be used and adapted.
The executor is a Parameterized Differential Evolution (PDE), a variant of DE designed to accommodate various mutation and crossover strategies dynamically.
- Parameterization: Flexible mutation strategies like DE/rand/1/bin or DE/best/2/exp can be selected based on problem characteristics.
- Parallel Execution: Core operations of PDE are optimized for parallel execution on GPUs (via JAX or PyTorch).
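To make the strategy names concrete, here is a dependency-free sketch of how a single trial vector could be built when the base vector, the number of difference vectors, and the crossover scheme are all parameters. It is not the ParamDE implementation (which runs batched on GPU); `make_trial` and its arguments are illustrative names, and only the `rand`/`best` base vectors and `bin`/`exp` crossovers are covered.

```python
import random


def make_trial(pop, fitness, i, F, CR, base="rand", n_diff=1, cross="bin", rng=random):
    """Build one trial vector for individual i (minimization)."""
    dim = len(pop[i])
    others = [j for j in range(len(pop)) if j != i]
    # One index for a random base vector plus two per difference vector.
    picks = rng.sample(others, 1 + 2 * n_diff)
    if base == "best":
        # picks[0] stays unused in this branch.
        base_vec = pop[min(range(len(pop)), key=fitness.__getitem__)]
    else:
        base_vec = pop[picks[0]]

    # Mutation: base + F * (sum of n_diff difference vectors).
    mutant = list(base_vec)
    for k in range(n_diff):
        a, b = pop[picks[1 + 2 * k]], pop[picks[2 + 2 * k]]
        for d in range(dim):
            mutant[d] += F * (a[d] - b[d])

    trial = list(pop[i])
    start = rng.randrange(dim)
    if cross == "bin":
        # Binomial: each gene comes from the mutant with probability CR;
        # position `start` is always taken from the mutant.
        for d in range(dim):
            if d == start or rng.random() < CR:
                trial[d] = mutant[d]
    else:
        # Exponential: copy a contiguous (wrapping) run of genes from the mutant.
        d, copied = start, 0
        while True:
            trial[d] = mutant[d]
            copied += 1
            d = (d + 1) % dim
            if copied >= dim or rng.random() >= CR:
                break
    return trial


rng = random.Random(0)
pop = [[rng.uniform(-5, 5) for _ in range(4)] for _ in range(8)]
fit = [sum(x * x for x in ind) for ind in pop]
t_rand_1_bin = make_trial(pop, fit, 0, F=0.5, CR=0.9, base="rand", n_diff=1, cross="bin", rng=rng)
t_best_2_exp = make_trial(pop, fit, 0, F=0.5, CR=0.9, base="best", n_diff=2, cross="exp", rng=rng)
```

The population must contain at least `2 * n_diff + 2` individuals so that enough distinct indices can be sampled.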
MetaDE integrates with the EvoX framework for distributed, GPU-accelerated evolutionary computation, significantly enhancing performance on large-scale optimization tasks.
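The evolver/executor relationship can be sketched as a two-level loop: an outer DE searches over DE parameters, and each candidate is scored by running an inner DE on the actual problem. The example below is a deliberately small, serial illustration in plain Python. It tunes only `F` and `CR` on a sphere function, whereas MetaDE also evolves strategy choices and evaluates candidates in parallel; all function names here are assumptions made for the sketch.

```python
import random


def sphere(x):
    return sum(v * v for v in x)


def de_run(objective, lb, ub, dim, pop_size, steps, F, CR, rng):
    """Plain DE/rand/1/bin (minimization); returns (best_vector, best_fitness)."""
    pop = [[rng.uniform(lb[d], ub[d]) for d in range(dim)] for _ in range(pop_size)]
    fit = [objective(ind) for ind in pop]
    for _ in range(steps):
        new_pop, new_fit = list(pop), list(fit)
        for i in range(pop_size):
            r1, r2, r3 = rng.sample([j for j in range(pop_size) if j != i], 3)
            forced = rng.randrange(dim)
            trial = []
            for d in range(dim):
                if d == forced or rng.random() < CR:
                    v = pop[r1][d] + F * (pop[r2][d] - pop[r3][d])
                    v = min(max(v, lb[d]), ub[d])  # clip to the bounds
                else:
                    v = pop[i][d]
                trial.append(v)
            f = objective(trial)
            if f <= fit[i]:  # greedy one-to-one selection
                new_pop[i], new_fit[i] = trial, f
        pop, fit = new_pop, new_fit
    best = min(range(pop_size), key=fit.__getitem__)
    return pop[best], fit[best]


def meta_objective(params, seed):
    """Executor: score one (F, CR) pair by the best fitness of an inner DE run."""
    F, CR = params
    # A fixed seed keeps the score deterministic for a given (F, CR).
    _, best_fit = de_run(sphere, [-100.0] * 5, [100.0] * 5, dim=5,
                         pop_size=10, steps=20, F=F, CR=CR, rng=random.Random(seed))
    return best_fit


# Evolver: the same DE routine, now searching over (F, CR) in [0, 1]^2.
best_params, best_fit = de_run(
    lambda p: meta_objective(p, seed=0),
    [0.0, 0.0], [1.0, 1.0], dim=2, pop_size=6, steps=5,
    F=0.5, CR=0.9, rng=random.Random(42),
)
print("best (F, CR):", best_params, "inner best fitness:", best_fit)
```

Each outer evaluation costs a full inner run, which is why batching the inner runs on a GPU matters at realistic sizes.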
import jax.numpy as jnp
import jax
from tqdm import tqdm
from util import StdSOMonitor, StdWorkflow
from algorithms.jax import create_batch_algorithm, decoder_de, MetaDE, ParamDE, DE
from problems.jax.sphere import Sphere

# Problem size and budgets
D = 10
BATCH_SIZE = 100
NUM_RUNS = 1
key_start = 42  # PRNG seed
STEPS = 50      # meta-level iterations
POP_SIZE = BATCH_SIZE
BASE_ALG_POP_SIZE = 100
BASE_ALG_STEPS = 100

# Bounds of the 6-dimensional parameter vector searched by the evolver
tiny_num = 1e-5
param_lb = jnp.array([0, 0, 0, 0, 1, 0])
param_ub = jnp.array([1, 1, 4 - tiny_num, 4 - tiny_num, 5 - tiny_num, 3 - tiny_num])

# Evolver: a standard DE over the parameter space
evolver = DE(
    lb=param_lb,
    ub=param_ub,
    pop_size=POP_SIZE,
    base_vector="rand",
    differential_weight=0.5,
    cross_probability=0.9,
)

# Executor: a batch of parameterized DE instances on the base problem
BatchDE = create_batch_algorithm(ParamDE, BATCH_SIZE, NUM_RUNS)
batch_de = BatchDE(
    lb=jnp.full((D,), -100),
    ub=jnp.full((D,), 100),
    pop_size=BASE_ALG_POP_SIZE,
)

base_problem = Sphere()
decoder = decoder_de
key = jax.random.PRNGKey(key_start)
monitor = StdSOMonitor(record_fit_history=False)

meta_problem = MetaDE(
    batch_de,
    base_problem,
    batch_size=BATCH_SIZE,
    num_runs=NUM_RUNS,
    base_alg_steps=BASE_ALG_STEPS,
)

workflow = StdWorkflow(
    algorithm=evolver,
    problem=meta_problem,
    pop_transform=decoder,
    monitor=monitor,
    record_pop=True,
)

key, subkey = jax.random.split(key)
state = workflow.init(subkey)

for step in tqdm(range(STEPS)):
    # Set the power_up flag only for the final meta-iteration,
    # as in the Brax example further down.
    power_up = 1 if step == STEPS - 1 else 0
    state = state.update_child("problem", {"power_up": power_up})
    state = workflow.step(state)

print(f"Best fitness: {monitor.get_best_fitness()}")
If you want to use the PyTorch backend, please refer to the PyTorch examples under
examples/pytorch/example.py
in this repository.
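In the example above, the last four entries of `param_ub` are integers minus `tiny_num`. One plausible reading is that these entries encode categorical choices as real numbers that are later floored, and subtracting `tiny_num` keeps the floored value strictly below the upper integer. The sketch below illustrates that reading only. The field names and category labels are assumptions made for illustration, and the repository's `decoder_de` may be implemented differently.

```python
import math

TINY = 1e-5
PARAM_LB = [0, 0, 0, 0, 1, 0]
PARAM_UB = [1, 1, 4 - TINY, 4 - TINY, 5 - TINY, 3 - TINY]

# Hypothetical labels, chosen only for illustration.
BASE_VECTORS = ["rand", "best", "pbest", "current"]
CROSSOVERS = ["bin", "exp", "arith"]


def decode(vec):
    """Turn a 6-entry real vector into a readable DE configuration."""
    clipped = [min(max(v, lo), hi) for v, lo, hi in zip(vec, PARAM_LB, PARAM_UB)]
    return {
        "F": clipped[0],
        "CR": clipped[1],
        "base_left": BASE_VECTORS[math.floor(clipped[2])],
        "base_right": BASE_VECTORS[math.floor(clipped[3])],
        "num_diff": math.floor(clipped[4]),
        "crossover": CROSSOVERS[math.floor(clipped[5])],
    }


print(decode([0.5, 0.9, 0.2, 3.5, 4.7, 1.5]))
```

With bounds like these, an entry in `[0, 4 - tiny_num]` always floors to one of 0, 1, 2, 3, so every index is valid without special handling of the upper edge.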
MetaDE supports several benchmark suites such as CEC2022. Here’s an example (JAX-based) for the CEC2022 test suite:
import jax.numpy as jnp
import jax
from tqdm import tqdm
from util import (
    StdSOMonitor,
    StdWorkflow
)
from algorithms.jax import create_batch_algorithm, decoder_de, MetaDE, ParamDE, DE
from problems.jax import CEC2022TestSuit

# Problem size and budgets
D = 10
FUNC_LIST = jnp.arange(12) + 1  # CEC2022 function numbers 1..12
BATCH_SIZE = 100
NUM_RUNS = 1
key_start = 42  # PRNG seed
STEPS = 50      # meta-level iterations
POP_SIZE = BATCH_SIZE
BASE_ALG_POP_SIZE = 100
BASE_ALG_STEPS = 100

# Bounds of the 6-dimensional parameter vector searched by the evolver
tiny_num = 1e-5
param_lb = jnp.array([0, 0, 0, 0, 1, 0])
param_ub = jnp.array([1, 1, 4 - tiny_num, 4 - tiny_num, 5 - tiny_num, 3 - tiny_num])

evolver = DE(
    lb=param_lb,
    ub=param_ub,
    pop_size=POP_SIZE,
    base_vector="rand", differential_weight=0.5, cross_probability=0.9
)

BatchDE = create_batch_algorithm(ParamDE, BATCH_SIZE, NUM_RUNS)
batch_de = BatchDE(
    lb=jnp.full((D,), -100),
    ub=jnp.full((D,), 100),
    pop_size=BASE_ALG_POP_SIZE,
)

# Run one independent meta-optimization per benchmark function
for func_num in FUNC_LIST:
    base_problem = CEC2022TestSuit.create(int(func_num))
    decoder = decoder_de
    key = jax.random.PRNGKey(key_start)
    monitor = StdSOMonitor(record_fit_history=False)
    print(type(base_problem).__name__)

    meta_problem = MetaDE(
        batch_de,
        base_problem,
        batch_size=BATCH_SIZE,
        num_runs=NUM_RUNS,
        base_alg_steps=BASE_ALG_STEPS
    )
    workflow = StdWorkflow(
        algorithm=evolver,
        problem=meta_problem,
        pop_transform=decoder,
        monitor=monitor,
        record_pop=True,
    )

    key, _ = jax.random.split(key)
    state = workflow.init(key)

    for i in tqdm(range(STEPS)):
        # Set the power_up flag only for the final meta-iteration,
        # as in the Brax example further down.
        power_up = 1 if i == STEPS - 1 else 0
        state = state.update_child("problem", {"power_up": power_up})
        state = workflow.step(state)

    print(f"Best_fitness: {monitor.get_best_fitness()}")
If you want to use the PyTorch backend, please refer to the PyTorch examples under
examples/pytorch/example_cec2022.py
in this repository.
MetaDE can also be used for real-world tasks like robotic control through evolutionary reinforcement learning. Note that running these examples requires installing CUDA-enabled JAX and brax==0.10.3.
from tqdm import tqdm
import problems
from jax import random
from flax import linen as nn
import jax.numpy as jnp
import jax
from util import StdSOMonitor, StdWorkflow, TreeAndVector, parse_opt_direction
from algorithms.jax import create_batch_algorithm, decoder_de, MetaDE, ParamDE, DE

steps = 20
pop_size = 100
key = jax.random.PRNGKey(42)


class HopperPolicy(nn.Module):
    """Small MLP policy: observation -> 32 -> 32 -> 3 actions."""

    @nn.compact
    def __call__(self, x):
        x = x.astype(jnp.float32)
        x = x.reshape(-1)
        x = nn.Dense(32)(x)
        x = nn.tanh(x)
        x = nn.Dense(32)(x)
        x = nn.tanh(x)
        x = nn.Dense(3)(x)
        x = nn.tanh(x)
        return x


# Initialize the policy and build an adapter between its weight tree and flat vectors
model = HopperPolicy()
weights = model.init(random.PRNGKey(42), jnp.zeros((11,)))
adapter = TreeAndVector(weights)
vector_form_weights = adapter.to_vector(weights)

# Bounds of the 6-dimensional parameter vector searched by the evolver
tiny_num = 1e-5
param_lb = jnp.array([0, 0, 0, 0, 1, 0])
param_ub = jnp.array([1, 1, 4 - tiny_num, 4 - tiny_num, 5 - tiny_num, 3 - tiny_num])

algorithm = DE(
    lb=param_lb,
    ub=param_ub,
    pop_size=pop_size,
    base_vector="rand", differential_weight=0.5, cross_probability=0.9
)

# The base algorithm searches over the flattened policy weights
BatchDE = create_batch_algorithm(ParamDE, pop_size, 1)
batch_de = BatchDE(
    lb=jnp.full((vector_form_weights.shape[0],), -10.0),
    ub=jnp.full((vector_form_weights.shape[0],), 10.0),
    pop_size=100,
)

base_problem = problems.jax.Brax(
    env_name="hopper",
    policy=jax.jit(model.apply),
    cap_episode=500,
)

meta_problem = MetaDE(
    batch_de,
    base_problem,
    batch_size=pop_size,
    num_runs=1,
    base_alg_steps=50,
    base_opt_direction="max",
    base_pop_transform=adapter.batched_to_tree,
)

monitor = StdSOMonitor(record_fit_history=False)
workflow = StdWorkflow(
    algorithm=algorithm,
    problem=meta_problem,
    monitor=monitor,
    pop_transform=decoder_de,
    record_pop=True,
)
monitor.set_opt_direction(parse_opt_direction("max"))

key, _ = jax.random.split(key)
state = workflow.init(key)

for i in tqdm(range(steps)):
    # Set the power_up flag only for the final meta-iteration
    power_up = 1 if i == steps - 1 else 0
    state = state.update_child("problem", {"power_up": power_up})
    state = workflow.step(state)

print(f"Best fitness: {monitor.get_best_fitness()}")
The PyTorch-based MetaDE does not yet support Brax problems.
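The Brax example relies on `TreeAndVector` so that DE, which operates on flat vectors, can search over the policy's nested weights. Its internals are not shown here, so the following is only a sketch of the general idea using plain nested dicts and lists. The helper names and the layer keys are hypothetical; the layer sizes follow the `HopperPolicy` in the example (11 inputs, two hidden layers of 32, 3 outputs).

```python
def make_dense(n_in, n_out):
    """Zero-initialized parameters of one dense layer as nested lists."""
    return {"kernel": [[0.0] * n_out for _ in range(n_in)], "bias": [0.0] * n_out}


def to_vector(tree):
    """Flatten nested dicts/lists of floats into one flat list (sorted key order)."""
    if isinstance(tree, dict):
        return [v for k in sorted(tree) for v in to_vector(tree[k])]
    if isinstance(tree, list):
        return [v for item in tree for v in to_vector(item)]
    return [tree]


def to_tree(vector, template):
    """Rebuild a tree shaped like `template` from a flat list."""
    it = iter(vector)

    def build(node):
        if isinstance(node, dict):
            return {k: build(node[k]) for k in sorted(node)}
        if isinstance(node, list):
            return [build(item) for item in node]
        return next(it)

    return build(template)


# Hypothetical layer keys; sizes follow the policy above: 11 -> 32 -> 32 -> 3.
weights = {
    "Dense_0": make_dense(11, 32),
    "Dense_1": make_dense(32, 32),
    "Dense_2": make_dense(32, 3),
}
flat = to_vector(weights)
# (11*32 + 32) + (32*32 + 32) + (32*3 + 3) = 1539 search dimensions
print(len(flat))

# Round trip: a flat candidate vector becomes a weight tree and back.
candidate = [float(i) for i in range(len(flat))]
rebuilt = to_tree(candidate, weights)
```

The length of the flat vector is the dimensionality the base algorithm has to search, which is why the bounds in the example are built from `vector_form_weights.shape[0]`.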
- Join the discussion and share your experiences on the GitHub Discussion Board.
- Join our QQ group (ID: 297969717).
@article{metade,
title = {{MetaDE}: Evolving Differential Evolution by Differential Evolution},
author = {Chen, Minyang and Feng, Chenchen and Cheng, Ran},
journal = {IEEE Transactions on Evolutionary Computation},
year = 2025,
doi = {10.1109/TEVC.2025.3541587}
}