refactor EvolutionMutator
marsggbo committed May 30, 2023
1 parent 44b0e6e commit f54784a
Showing 1 changed file with 67 additions and 72 deletions.
hyperbox/mutator/evolution_mutator.py (139 changes: 67 additions & 72 deletions)
@@ -41,9 +41,15 @@ def is_constraint_satisfied(constraint, obj):


class EvolutionMutator(RandomMutator):
def __init__(
def __init__(self, model):
super(EvolutionMutator, self).__init__(model)

def search(
self,
model,
eval_func: callable, # evaluation function of arch performance
eval_kwargs: dict, # kwargs of eval_func
eval_metrics_order: dict, # order of eval metrics, e.g. {'accuracy': 'max', 'f1-score': 'max'}

warmup_epochs: int=0,
evolution_epochs: int=100,
population_num: int=50,
@@ -54,11 +60,16 @@ def __init__(
mutation_num: Optional[Union[int, float]]=0.5,
mutation_prob: float=0.3,
topk: Optional[Union[int, float]]=10,
*args, **kwargs
):
'''Init Args:
model: model (Supernet) to be searched

flops_limit: Optional[Union[list, float]]=None, # MFLOPs, None means no limit
size_limit: Optional[Union[list, float]]=None, # MB, None means no limit
log_dir: str='evolution_logs',
resume_from_checkpoint: Optional[str]=None,
to_save_checkpoint: bool=True,
to_plot_pareto: bool=True,
figname: str='evolution_pareto.pdf'
):
'''Args:
####Evolution parameters####
warmup_epochs: number of warmup epochs
evolution_epochs: number of evolution epochs
@@ -72,8 +83,23 @@ def __init__(
mutation_num: mutation num for each epoch
mutation_prob: mutation probability
topk: top k candidates
####Evaluation function####
eval_func: evaluation function of arch performance
eval_kwargs: kwargs of eval_func, must contain arguments of `model` and `mutator`
eval_metrics_order: order of eval metrics, e.g. {'accuracy': 'max', 'f1-score': 'max'}
####Limitation####
flops_limit: flops limit for each epoch
size_limit: size limit for each epoch
####Logging parameters####
log_dir: log directory
resume_from_checkpoint: resume from checkpoint
to_save_checkpoint: save checkpoint or not
to_plot_pareto: plot pareto or not
figname: pareto figure name
'''
super(EvolutionMutator, self).__init__(model)
if selection_alg=='best' and len(eval_metrics_order)>1:
raise ValueError('`selection_alg` must be `nsga2` when there is more than one eval metric')
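
For reference, a minimal sketch of the contract documented above. Per the docstring, `eval_kwargs` must cover the extra arguments of `eval_func` beyond `model` and `mutator`, and candidates are ranked by `eval_metrics_order`; the function body, metric name, and return format below are illustrative assumptions only:

def eval_func(mutator, model):
    # hypothetical scorer: evaluate the subnet currently sampled by
    # `mutator` and return one value per key of `eval_metrics_order`
    correct, total = 0, 0
    # ... run `model` on a held-out validation set here ...
    return {'accuracy': correct / max(total, 1)}

order = {'accuracy': 'max'}  # one metric, so selection_alg='best' is allowed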

@@ -87,73 +113,15 @@ def __init__(
self.mutation_num = get_int_num(mutation_num, population_num)
self.mutation_prob = mutation_prob
self.random_num = population_num - (self.crossover_num + self.mutation_num)
self.topk = get_int_num(topk, population_num)

self.memory = []
self.vis_dict = {}
self.keep_top_k = {self.selection_num: [], self.topk: []}
self.epoch = 0
self.candidates = []

def save_checkpoint(self, epoch=None):
if not os.path.exists(self.log_dir):
os.makedirs(self.log_dir)
info = {}
info['memory'] = self.memory
info['candidates'] = self.candidates
info['vis_dict'] = self.vis_dict
info['keep_top_k'] = self.keep_top_k
info['epoch'] = self.epoch
if epoch is not None:
ckpt_name = f"{self.log_dir}/checkpoint_{epoch}.pth.tar"
else:
ckpt_name = f"{self.log_dir}/{self.checkpoint_name}"
torch.save(info, ckpt_name)
log.info(f'save checkpoint to {ckpt_name}')

def load_checkpoint(self):
if not os.path.exists(self.resume_from_checkpoint):
return False
info = torch.load(self.resume_from_checkpoint)
self.memory = info['memory']
self.candidates = info['candidates']
self.vis_dict = info['vis_dict']
self.keep_top_k = info['keep_top_k']
self.epoch = info['epoch']

log.info(f'load checkpoint from {self.resume_from_checkpoint}')
return True

def search(
self,
eval_func: callable, # evaluation function of arch performance
eval_kwargs: dict, # kwargs of eval_func
eval_metrics_order: dict, # order of eval metrics, e.g. {'accuracy': 'max', 'f1-score': 'max'}
flops_limit: Optional[Union[list, float]]=None, # MFLOPs, None means no limit
size_limit: Optional[Union[list, float]]=None, # MB, None means no limit
log_dir: str='evolution_logs',
resume_from_checkpoint: Optional[str]=None,
to_save_checkpoint: bool=True,
to_plot_pareto: bool=True,
figname: str='evolution_pareto.pdf',
):
'''Args:
####Evaluation function####
eval_func: evaluation function of arch performance
eval_kwargs: kwargs of eval_func, must contain arguments of `model` and `mutator`
eval_metrics_order: order of eval metrics, e.g. {'accuracy': 'max', 'f1-score': 'max'}
####Limitation####
flops_limit: flops limit for each epoch
size_limit: size limit for each epoch
self.topk = get_int_num(topk, self.population_num)
self.keep_top_k = {self.selection_num: [], self.topk: []}

####Logging parameters####
log_dir: log directory
resume_from_checkpoint: resume from checkpoint
to_save_checkpoint: save checkpoint or not
to_plot_pareto: plot pareto or not
figname: pareto figure name
'''
self.eval_func = eval_func
self.eval_kwargs = eval_kwargs
self.eval_metrics_order = eval_metrics_order
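
`get_int_num`, called above to normalize `crossover_num`, `mutation_num`, and `topk`, is imported from elsewhere and not shown in this diff. Judging from the `Optional[Union[int, float]]` annotations, it presumably treats a float as a fraction of the population and an int as an absolute count; a sketch under that assumption:

def get_int_num(num, total):
    # assumed semantics: float -> fraction of `total`, int -> used as-is
    if isinstance(num, float):
        return int(num * total)
    return int(num)

# e.g. mutation_num=0.2 with population_num=30 (the values used in the
# example at the bottom of this file) yields 6 mutated candidates per epoch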
@@ -434,6 +402,35 @@ def crossover(self, arch1, arch2, crossover_prob):
cross_arch[key] = deepcopy(arch2[key])
return cross_arch

def save_checkpoint(self, epoch=None):
if not os.path.exists(self.log_dir):
os.makedirs(self.log_dir)
info = {}
info['memory'] = self.memory
info['candidates'] = self.candidates
info['vis_dict'] = self.vis_dict
info['keep_top_k'] = self.keep_top_k
info['epoch'] = self.epoch
if epoch is not None:
ckpt_name = f"{self.log_dir}/checkpoint_{epoch}.pth.tar"
else:
ckpt_name = f"{self.log_dir}/{self.checkpoint_name}"
torch.save(info, ckpt_name)
log.info(f'save checkpoint to {ckpt_name}')

def load_checkpoint(self):
if not os.path.exists(self.resume_from_checkpoint):
return False
info = torch.load(self.resume_from_checkpoint)
self.memory = info['memory']
self.candidates = info['candidates']
self.vis_dict = info['vis_dict']
self.keep_top_k = info['keep_top_k']
self.epoch = info['epoch']

log.info(f'load checkpoint from {self.resume_from_checkpoint}')
return True

@classmethod
def plot_real_proxy_metrics(
cls,
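
Only the last lines of `crossover` survive the hunk truncation above. From the signature and the visible `cross_arch[key] = deepcopy(arch2[key])`, it appears to perform per-key uniform crossover on the architecture dict; a hedged reconstruction of the idea, not necessarily the exact body:

import random
from copy import deepcopy

def crossover_sketch(arch1, arch2, crossover_prob):
    # assumed behavior: start from parent 1 and, independently per
    # decision key, take parent 2's gene with probability `crossover_prob`
    cross_arch = deepcopy(arch1)
    for key in cross_arch:
        if random.random() < crossover_prob:
            cross_arch[key] = deepcopy(arch2[key])
    return cross_arch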
@@ -527,8 +524,9 @@ def eval_func(mutator, network, da=32,gs=5432,gsrh=764):
net = NASBench201Network(num_classes=10).cuda()
# net = NASBenchMBNet(num_classes=10).cuda()

em = EvolutionMutator(
net,
em = EvolutionMutator(net)

topk = em.search(
warmup_epochs=0,
evolution_epochs=2,
population_num=30,
@@ -539,9 +537,6 @@ def eval_func(mutator, network, da=32,gs=5432,gsrh=764):
mutation_num=0.2,
mutation_prob=0.3,
topk=5,
)

topk = em.search(
eval_func=eval_func,
eval_kwargs={'da':352,'gs':32,'gsrh':764},
eval_metrics_order=order,
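Assembled from the visible fragments of the truncated example above (the collapsed hunks hide a few arguments), the refactor splits construction from searching; `net`, `eval_func`, and `order` are the objects defined earlier in the file:

em = EvolutionMutator(net)   # the constructor now takes only the supernet
topk = em.search(            # all search settings moved into search()
    eval_func=eval_func,
    eval_kwargs={'da': 352, 'gs': 32, 'gsrh': 764},
    eval_metrics_order=order,
    warmup_epochs=0,
    evolution_epochs=2,
    population_num=30,
    mutation_num=0.2,
    mutation_prob=0.3,
    topk=5,
    # resume_from_checkpoint='evolution_logs/checkpoint_1.pth.tar',
    # optional; the path pattern follows save_checkpoint above
)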
