import logging

import torch
+import torch.optim as optim

from robustbench.data import load_cifar10c
+from robustbench.model_zoo.enums import ThreatModel
+from robustbench.utils import load_model
from robustbench.utils import clean_accuracy as accuracy

-from tent import tent
+import tent
+import norm
+
from conf import cfg, load_cfg_fom_args


+logger = logging.getLogger(__name__)
+
+
def evaluate(cfg_file):
    load_cfg_fom_args(cfg_file=cfg_file,
                      description="CIFAR-10-C evaluation.")
-    logger = logging.getLogger(__name__)
+    # configure model
+    base_model = load_model(cfg.MODEL.ARCH, cfg.CKPT_DIR,
+                            cfg.CORRUPTION.DATASET, ThreatModel.corruptions).cuda()
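+    # select the test-time adaptation method: source (no adaptation), norm, or tent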
+    if cfg.MODEL.ADAPTATION == "source":
+        logger.info("test-time adaptation: NONE")
+        model = setup_source(base_model)
+    if cfg.MODEL.ADAPTATION == "norm":
+        logger.info("test-time adaptation: NORM")
+        model = setup_norm(base_model)
+    if cfg.MODEL.ADAPTATION == "tent":
+        logger.info("test-time adaptation: TENT")
+        model = setup_tent(base_model)
+    # evaluate on each severity and type of corruption in turn
    for severity in cfg.CORRUPTION.SEVERITY:
        for corruption_type in cfg.CORRUPTION.TYPE:
+            # reset adaptation for each combination of corruption x severity
+            # note: for evaluation protocol, but not necessarily needed
+            try:
+                model.reset()
+                logger.info("resetting model")
+            except AttributeError:
+                logger.warning("not resetting model")
            x_test, y_test = load_cifar10c(cfg.CORRUPTION.NUM_EX,
                                           severity, cfg.DATA_DIR, False,
                                           [corruption_type])
            x_test, y_test = x_test.cuda(), y_test.cuda()
-            model = tent(cfg.CORRUPTION.MODEL)
-            acc = accuracy(model, x_test, y_test, cfg.OPTIM.BATCH_SIZE)
-            logger.info('accuracy [{}{}]: {:.2%}'.format(
-                corruption_type, severity, acc))
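+            # evaluate in batches; for norm/tent the model adapts during its forward pass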
+            acc = accuracy(model, x_test, y_test, cfg.TEST.BATCH_SIZE)
+            err = 1. - acc
+            logger.info(f"error % [{corruption_type}{severity}]: {err:.2%}")
+
+
+def setup_source(model):
+    """Set up the baseline source model without adaptation."""
+    model.eval()
+    logger.info("model for evaluation: %s", model)
+    return model
+
+
+def setup_norm(model):
+    """Set up test-time normalization adaptation.
+
+    Adapt by normalizing features with test batch statistics.
+    The statistics are measured independently for each batch;
+    no running average or other cross-batch estimation is used.
+    """
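+    # wrap the model so its normalization layers use test-batch statistics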
+    norm_model = norm.Norm(model)
+    logger.info("model for adaptation: %s", model)
+    stats, stat_names = norm.collect_stats(model)
+    logger.info("stats for adaptation: %s", stat_names)
+    return norm_model
+
+
+def setup_tent(model):
+    """Set up tent adaptation.
+
+    Configure the model for training + feature modulation by batch statistics,
+    collect the parameters for feature modulation by gradient optimization,
+    set up the optimizer, and then tent the model.
+    """
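+    # configure the model, then collect only the feature-modulation parameters
+    # (in tent, the normalization layers' scale and shift) for optimization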
+    model = tent.configure_model(model)
+    params, param_names = tent.collect_params(model)
+    optimizer = setup_optimizer(params)
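+    # wrap the model to adapt for cfg.OPTIM.STEPS update(s) per test batch;
+    # episodic mode resets the adapted state between batches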
+    tent_model = tent.Tent(model, optimizer,
+                           steps=cfg.OPTIM.STEPS,
+                           episodic=cfg.MODEL.EPISODIC)
+    logger.info("model for adaptation: %s", model)
+    logger.info("params for adaptation: %s", param_names)
+    logger.info("optimizer for adaptation: %s", optimizer)
+    return tent_model
+
+
+def setup_optimizer(params):
+    """Set up optimizer for tent adaptation.
+
+    Tent needs an optimizer for test-time entropy minimization.
+    In principle, tent could make use of any gradient optimizer.
+    In practice, we advise choosing Adam or SGD+momentum.
+    For optimization settings, we advise using the settings from the end of
+    training, if known, or starting with a low learning rate (like 0.001) if not.
+
+    For best results, try tuning the learning rate and batch size.
+    """
+    if cfg.OPTIM.METHOD == 'Adam':
+        return optim.Adam(params,
+                          lr=cfg.OPTIM.LR,
+                          betas=(cfg.OPTIM.BETA, 0.999),
+                          weight_decay=cfg.OPTIM.WD)
+    elif cfg.OPTIM.METHOD == 'SGD':
+        return optim.SGD(params,
+                         lr=cfg.OPTIM.LR,
+                         momentum=cfg.OPTIM.MOMENTUM,
+                         dampening=cfg.OPTIM.DAMPENING,
+                         weight_decay=cfg.OPTIM.WD,
+                         nesterov=cfg.OPTIM.NESTEROV)
+    else:
+        raise NotImplementedError


if __name__ == '__main__':