Skip to content
This repository was archived by the owner on Feb 15, 2025. It is now read-only.

Commit fc5705f

Browse files
committed
turn adaptation methods into modules, simplify, and clarify wording
- clarify status of example and reference code: do try the example! - turn methods into importable modules, decoupled from their configuration and arguments, for ease of experimentation and adoption - comment and document
1 parent 1ffcd70 commit fc5705f

File tree

8 files changed

+327
-268
lines changed

8 files changed

+327
-268
lines changed

README.md

+20-19
Original file line numberDiff line numberDiff line change
@@ -3,36 +3,38 @@
33
This is the official project repository for [Tent: Fully-Test Time Adaptation by Entropy Minimization](https://openreview.net/forum?id=uXl3bZLkr3c) by
44
Dequan Wang\*, Evan Shelhamer\*, Shaoteng Liu, Bruno Olshausen, and Trevor Darrell (ICLR 2021, spotlight).
55

6-
Tent equips a model to adapt itself to new and different data ☀️ 🌧 ❄️ during testing.
7-
Tent updates online and batch-by-batch to reduce error on dataset shifts like corruptions, simulation-to-real discrepancies, and other differences between training and testing data.
6+
⛺️ Tent equips a model to adapt itself to new and different data during testing ☀️ 🌧❄️.
7+
Tented models adapt online and batch-by-batch to reduce error on dataset shifts like corruptions, simulation-to-real discrepancies, and other differences between training and testing data.
8+
This kind of adaptation is effective and efficient: tent makes just one update per batch to not interrupt inference.
89

9-
Our **example code** illustrates the method and provides representative results for image corruptions on CIFAR-10-C.
10-
Note that the exact details of the model, optimization, etc. differ from the paper, so this is not for reproduction, but for explanation.
10+
To illustrate the tent method and fully test-time adaptation setting we provide **example code** for adapting to image corruptions on CIFAR-10-C.
11+
The purpose of the example is explanation, not reproduction: exact details of the model architecture, optimization settings, etc. may differ from the paper.
12+
That said, the results should be representative, so do give it a try and experiment!
1113

12-
Please check back soon for our **reference code** to reproduce and extend tent!
14+
Please check back soon for **reference code** to exactly reproduce the ImageNet-C results in the paper.
1315

1416
## Example: Adapting to Image Corruptions on CIFAR-10-C
1517

16-
This example compares a baseline without adaptation (base), test-time normalization that updates feature statistics during testing (norm), and our method for entropy minimization during testing (tent).
18+
This example compares a baseline without adaptation (source), test-time normalization for updating feature statistics during testing (norm), and our method for entropy minimization during testing (tent).
1719

1820
- Dataset: [CIFAR-10-C](https://github.com/hendrycks/robustness/), with 15 corruption types and 5 levels.
1921
- Model: [WRN-28-10](https://github.com/RobustBench/robustbench), the default model for RobustBench.
2022

2123
**Usage**:
2224

2325
```python
24-
python cifar10c.py --cfg cfgs/base.yaml
26+
python cifar10c.py --cfg cfgs/source.yaml
2527
python cifar10c.py --cfg cfgs/norm.yaml
2628
python cifar10c.py --cfg cfgs/tent.yaml
2729
```
2830

2931
**Result**: tent reduces the error (%) across corruption types at the most severe level of corruption (level 5).
3032

31-
| | mean | gauss_noise | shot_noise | impulse_noise | defocus_blur | glass_blur | motion_blur | zoom_blur | snow | frost | fog | brightness | contrast | elastic_trans | pixelate | jpeg |
32-
| ---------------------------------------------------- | ---: | ----------: | ---------: | ------------: | -----------: | ---------: | ----------: | --------: | ---: | ----: | ---: | ---------: | -------: | ------------: | -------: | ---: |
33-
| [base](./cifar10c.py) | 43.5 | 72.3 | 65.7 | 72.9 | 46.9 | 54.3 | 34.8 | 42.0 | 25.1 | 41.3 | 26.0 | 9.3 | 46.7 | 26.6 | 58.5 | 30.3 |
34-
| [norm](./norm.py) | 20.4 | 28.1 | 26.1 | 36.3 | 12.8 | 35.3 | 14.2 | 12.1 | 17.3 | 17.4 | 15.3 | 8.4 | 12.6 | 23.8 | 19.7 | 27.3 |
35-
| [tent](./tent.py) | 18.6 | 24.8 | 23.5 | 33.0 | 11.9 | 31.9 | 13.7 | 10.8 | 15.9 | 16.2 | 13.7 | 7.9 | 12.1 | 22.0 | 17.3 | 24.2 |
33+
| | mean | gauss_noise | shot_noise | impulse_noise | defocus_blur | glass_blur | motion_blur | zoom_blur | snow | frost | fog | brightness | contrast | elastic_trans | pixelate | jpeg |
34+
| ---------------------------------------------------------- | ---: | ----------: | ---------: | ------------: | -----------: | ---------: | ----------: | --------: | ---: | ----: | ---: | ---------: | -------: | ------------: | -------: | ---: |
35+
| source [code](./cifar10c.py) [config](./cfgs/source.yaml) | 43.5 | 72.3 | 65.7 | 72.9 | 46.9 | 54.3 | 34.8 | 42.0 | 25.1 | 41.3 | 26.0 | 9.3 | 46.7 | 26.6 | 58.5 | 30.3 |
36+
| norm [code](./norm.py) [config](./cfgs/norm.yaml) | 20.4 | 28.1 | 26.1 | 36.3 | 12.8 | 35.3 | 14.2 | 12.1 | 17.3 | 17.4 | 15.3 | 8.4 | 12.6 | 23.8 | 19.7 | 27.3 |
37+
| tent [code](./tent.py) [config](./cfgs/tent.yaml) | 18.6 | 24.8 | 23.5 | 33.0 | 12.0 | 31.8 | 13.7 | 10.8 | 15.9 | 16.2 | 13.7 | 7.9 | 12.1 | 22.0 | 17.3 | 24.2 |
3638

3739
See the full results for this example in the [wandb report](https://wandb.ai/tent/cifar10c).
3840

@@ -45,12 +47,11 @@ Please contact Dequan Wang and Evan Shelhamer at dqwang AT cs.berkeley.edu and s
4547
If the tent method or fully test-time adaptation setting are helpful in your research, please consider citing our paper:
4648

4749
```bibtex
48-
@inproceedings{
49-
wang2021tent,
50-
title={Tent: Fully Test-Time Adaptation by Entropy Minimization},
51-
author={Dequan Wang and Evan Shelhamer and Shaoteng Liu and Bruno Olshausen and Trevor Darrell},
52-
booktitle={International Conference on Learning Representations},
53-
year={2021},
54-
url={https://openreview.net/forum?id=uXl3bZLkr3c}
50+
@inproceedings{wang2021tent,
51+
title={Tent: Fully Test-Time Adaptation by Entropy Minimization},
52+
author={Dequan Wang and Evan Shelhamer and Shaoteng Liu and Bruno Olshausen and Trevor Darrell},
53+
booktitle={International Conference on Learning Representations},
54+
year={2021},
55+
url={https://openreview.net/forum?id=uXl3bZLkr3c}
5556
}
5657
```

cfgs/norm.yaml

+6-11
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1+
MODEL:
2+
ADAPTATION: norm
3+
ARCH: Standard
4+
TEST:
5+
BATCH_SIZE: 200
16
CORRUPTION:
2-
MODEL: Standard
3-
EVAL_ONLY: True
7+
DATASET: cifar10
48
SEVERITY:
59
- 5
610
- 4
@@ -23,12 +27,3 @@ CORRUPTION:
2327
- elastic_transform
2428
- pixelate
2529
- jpeg_compression
26-
BN:
27-
FUNC: TrainModeBatchNorm2d
28-
OPTIM:
29-
BATCH_SIZE: 200
30-
METHOD: Adam
31-
ITER: 1
32-
BETA: 0.9
33-
LR: 1e-3
34-
WD: 0.

cfgs/base.yaml cfgs/source.yaml

+6-11
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1+
MODEL:
2+
ADAPTATION: source
3+
ARCH: Standard
4+
TEST:
5+
BATCH_SIZE: 200
16
CORRUPTION:
2-
MODEL: Standard
3-
EVAL_ONLY: True
7+
DATASET: cifar10
48
SEVERITY:
59
- 5
610
- 4
@@ -23,12 +27,3 @@ CORRUPTION:
2327
- elastic_transform
2428
- pixelate
2529
- jpeg_compression
26-
BN:
27-
FUNC: FrozenMeanVarBatchNorm2d
28-
OPTIM:
29-
BATCH_SIZE: 200
30-
METHOD: Adam
31-
ITER: 1
32-
BETA: 0.9
33-
LR: 1e-3
34-
WD: 0.

cfgs/tent.yaml

+7-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1+
MODEL:
2+
ADAPTATION: tent
3+
ARCH: Standard
4+
TEST:
5+
BATCH_SIZE: 200
16
CORRUPTION:
2-
MODEL: Standard
3-
EVAL_ONLY: False
7+
DATASET: cifar10
48
SEVERITY:
59
- 5
610
- 4
@@ -23,12 +27,9 @@ CORRUPTION:
2327
- elastic_transform
2428
- pixelate
2529
- jpeg_compression
26-
BN:
27-
FUNC: TrainModeBatchNorm2d
2830
OPTIM:
29-
BATCH_SIZE: 200
3031
METHOD: Adam
31-
ITER: 1
32+
STEPS: 1
3233
BETA: 0.9
3334
LR: 1e-3
3435
WD: 0.

cifar10c.py

+99-6
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,121 @@
11
import logging
22

33
import torch
4+
import torch.optim as optim
45

56
from robustbench.data import load_cifar10c
7+
from robustbench.model_zoo.enums import ThreatModel
8+
from robustbench.utils import load_model
69
from robustbench.utils import clean_accuracy as accuracy
710

8-
from tent import tent
11+
import tent
12+
import norm
13+
914
from conf import cfg, load_cfg_fom_args
1015

1116

17+
logger = logging.getLogger(__name__)
18+
19+
1220
def evaluate(cfg_file):
1321
load_cfg_fom_args(cfg_file=cfg_file,
1422
description="CIFAR-10-C evaluation.")
15-
logger = logging.getLogger(__name__)
23+
# configure model
24+
base_model = load_model(cfg.MODEL.ARCH, cfg.CKPT_DIR,
25+
cfg.CORRUPTION.DATASET, ThreatModel.corruptions).cuda()
26+
if cfg.MODEL.ADAPTATION == "source":
27+
logger.info("test-time adaptation: NONE")
28+
model = setup_source(base_model)
29+
if cfg.MODEL.ADAPTATION == "norm":
30+
logger.info("test-time adaptation: NORM")
31+
model = setup_norm(base_model)
32+
if cfg.MODEL.ADAPTATION == "tent":
33+
logger.info("test-time adaptation: TENT")
34+
model = setup_tent(base_model)
35+
# evaluate on each severity and type of corruption in turn
1636
for severity in cfg.CORRUPTION.SEVERITY:
1737
for corruption_type in cfg.CORRUPTION.TYPE:
38+
# reset adaptation for each combination of corruption x severity
39+
# note: for evaluation protocol, but not necessarily needed
40+
try:
41+
model.reset()
42+
logger.info("resetting model")
43+
except:
44+
logger.warning("not resetting model")
1845
x_test, y_test = load_cifar10c(cfg.CORRUPTION.NUM_EX,
1946
severity, cfg.DATA_DIR, False,
2047
[corruption_type])
2148
x_test, y_test = x_test.cuda(), y_test.cuda()
22-
model = tent(cfg.CORRUPTION.MODEL)
23-
acc = accuracy(model, x_test, y_test, cfg.OPTIM.BATCH_SIZE)
24-
logger.info('accuracy [{}{}]: {:.2%}'.format(
25-
corruption_type, severity, acc))
49+
acc = accuracy(model, x_test, y_test, cfg.TEST.BATCH_SIZE)
50+
err = 1. - acc
51+
logger.info(f"error % [{corruption_type}{severity}]: {err:.2%}")
52+
53+
54+
def setup_source(model):
55+
"""Set up the baseline source model without adaptation."""
56+
model.eval()
57+
logger.info(f"model for evaluation: %s", model)
58+
return model
59+
60+
61+
def setup_norm(model):
62+
"""Set up test-time normalization adaptation.
63+
64+
Adapt by normalizing features with test batch statistics.
65+
The statistics are measured independently for each batch;
66+
no running average or other cross-batch estimation is used.
67+
"""
68+
norm_model = norm.Norm(model)
69+
logger.info(f"model for adaptation: %s", model)
70+
stats, stat_names = norm.collect_stats(model)
71+
logger.info(f"stats for adaptation: %s", stat_names)
72+
return norm_model
73+
74+
75+
def setup_tent(model):
76+
"""Set up tent adaptation.
77+
78+
Configure the model for training + feature modulation by batch statistics,
79+
collect the parameters for feature modulation by gradient optimization,
80+
set up the optimizer, and then tent the model.
81+
"""
82+
model = tent.configure_model(model)
83+
params, param_names = tent.collect_params(model)
84+
optimizer = setup_optimizer(params)
85+
tent_model = tent.Tent(model, optimizer,
86+
steps=cfg.OPTIM.STEPS,
87+
episodic=cfg.MODEL.EPISODIC)
88+
logger.info(f"model for adaptation: %s", model)
89+
logger.info(f"params for adaptation: %s", param_names)
90+
logger.info(f"optimizer for adaptation: %s", optimizer)
91+
return tent_model
92+
93+
94+
def setup_optimizer(params):
95+
"""Set up optimizer for tent adaptation.
96+
97+
Tent needs an optimizer for test-time entropy minimization.
98+
In principle, tent could make use of any gradient optimizer.
99+
In practice, we advise choosing Adam or SGD+momentum.
100+
For optimization settings, we advise to use the settings from the end of
101+
trainig, if known, or start with a low learning rate (like 0.001) if not.
102+
103+
For best results, try tuning the learning rate and batch size.
104+
"""
105+
if cfg.OPTIM.METHOD == 'Adam':
106+
return optim.Adam(params,
107+
lr=cfg.OPTIM.LR,
108+
betas=(cfg.OPTIM.BETA, 0.999),
109+
weight_decay=cfg.OPTIM.WD)
110+
elif cfg.OPTIM.METHOD == 'SGD':
111+
return optim.SGD(params,
112+
lr=cfg.OPTIM.LR,
113+
momentum=cfg.OPTIM.MOMENTUM,
114+
dampening=cfg.OPTIM.DAMPENING,
115+
weight_decay=cfg.OPTIM.WD,
116+
nesterov=cfg.OPTIM.NESTEROV)
117+
else:
118+
raise NotImplementedError
26119

27120

28121
if __name__ == '__main__':

0 commit comments

Comments
 (0)