Skip to content

Commit 2a857e9

Browse files
c00wpytorchmergebot
authored andcommitted
config: Add env_name_default and env_name_force to Config (pytorch#138956)
This allows Configs to handle setting their defaults (or overriding themselves) via environment variables. The environment variables are resolved at install time (which is usually import time). This is done 1) to avoid any race conditions between threads etc..., but 2) to help encourage people to just go modify the configs directly, vs overriding environment variables to change pytorch behaviour. Pull Request resolved: pytorch#138956 Approved by: https://github.com/ezyang ghstack dependencies: pytorch#138766
1 parent 1270c78 commit 2a857e9

File tree

4 files changed

+88
-8
lines changed

4 files changed

+88
-8
lines changed

test/dynamo_skips/TestConfigModule.test_env_name_semantics

Whitespace-only changes.

test/test_utils_config_module.py

+37-6
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
# Owner(s): ["module: unknown"]
2+
import os
23
import pickle
34

5+
6+
os.environ["ENV_TRUE"] = "1"
7+
os.environ["ENV_FALSE"] = "0"
8+
49
from torch.testing._internal import fake_config_module as config
510
from torch.testing._internal.common_utils import run_tests, TestCase
611
from torch.utils._config_module import _UNSET_SENTINEL
@@ -70,6 +75,17 @@ def test_reference_semantics(self):
7075
for k in config._config:
7176
config._config[k].user_override = _UNSET_SENTINEL
7277

78+
def test_env_name_semantics(self):
79+
self.assertTrue(config.e_env_default)
80+
self.assertFalse(config.e_env_default_FALSE)
81+
self.assertTrue(config.e_env_force)
82+
config.e_env_default = False
83+
self.assertFalse(config.e_env_default)
84+
config.e_env_force = False
85+
self.assertTrue(config.e_env_force)
86+
for k in config._config:
87+
config._config[k].user_override = _UNSET_SENTINEL
88+
7389
def test_save_config(self):
7490
p = config.save_config()
7591
self.assertEqual(
@@ -93,6 +109,9 @@ def test_save_config(self):
93109
"e_config": True,
94110
"e_jk": True,
95111
"e_jk_false": False,
112+
"e_env_default": True,
113+
"e_env_default_FALSE": False,
114+
"e_env_force": True,
96115
},
97116
)
98117
config.e_bool = False
@@ -123,6 +142,9 @@ def test_save_config_portable(self):
123142
"e_config": True,
124143
"e_jk": True,
125144
"e_jk_false": False,
145+
"e_env_default": True,
146+
"e_env_default_FALSE": False,
147+
"e_env_force": True,
126148
},
127149
)
128150
config.e_bool = False
@@ -152,35 +174,35 @@ def test_codegen_config(self):
152174

153175
def test_get_hash(self):
154176
self.assertEqual(
155-
config.get_hash(), b"\xa8\xe0\x9b\xfc*\xc4P\xb5g\x1e_\x03 \x7fA\x05"
177+
config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c"
156178
)
157179
# Test cached value
158180
self.assertEqual(
159-
config.get_hash(), b"\xa8\xe0\x9b\xfc*\xc4P\xb5g\x1e_\x03 \x7fA\x05"
181+
config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c"
160182
)
161183
self.assertEqual(
162-
config.get_hash(), b"\xa8\xe0\x9b\xfc*\xc4P\xb5g\x1e_\x03 \x7fA\x05"
184+
config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c"
163185
)
164186
config._hash_digest = "fake"
165187
self.assertEqual(config.get_hash(), "fake")
166188

167189
config.e_bool = False
168190
self.assertNotEqual(
169-
config.get_hash(), b"\xa8\xe0\x9b\xfc*\xc4P\xb5g\x1e_\x03 \x7fA\x05"
191+
config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c"
170192
)
171193
config.e_bool = True
172194

173195
# Test ignored values
174196
config.e_compile_ignored = False
175197
self.assertEqual(
176-
config.get_hash(), b"\xa8\xe0\x9b\xfc*\xc4P\xb5g\x1e_\x03 \x7fA\x05"
198+
config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c"
177199
)
178200
for k in config._config:
179201
config._config[k].user_override = _UNSET_SENTINEL
180202

181203
def test_dict_copy_semantics(self):
182204
p = config.shallow_copy_dict()
183-
self.assertEqual(
205+
self.assertDictEqual(
184206
p,
185207
{
186208
"e_bool": True,
@@ -202,6 +224,9 @@ def test_dict_copy_semantics(self):
202224
"e_config": True,
203225
"e_jk": True,
204226
"e_jk_false": False,
227+
"e_env_default": True,
228+
"e_env_default_FALSE": False,
229+
"e_env_force": True,
205230
},
206231
)
207232
p2 = config.to_dict()
@@ -227,6 +252,9 @@ def test_dict_copy_semantics(self):
227252
"e_config": True,
228253
"e_jk": True,
229254
"e_jk_false": False,
255+
"e_env_default": True,
256+
"e_env_default_FALSE": False,
257+
"e_env_force": True,
230258
},
231259
)
232260
p3 = config.get_config_copy()
@@ -252,6 +280,9 @@ def test_dict_copy_semantics(self):
252280
"e_config": True,
253281
"e_jk": True,
254282
"e_jk_false": False,
283+
"e_env_default": True,
284+
"e_env_default_FALSE": False,
285+
"e_env_force": True,
255286
},
256287
)
257288

torch/testing/_internal/fake_config_module.py

+3
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
e_config = Config(default=True)
2222
e_jk = Config(justknob="does_not_exist")
2323
e_jk_false = Config(justknob="does_not_exist", default=False)
24+
e_env_default = Config(env_name_default="ENV_TRUE", default=False)
25+
e_env_default_FALSE = Config(env_name_default="ENV_FALSE", default=True)
26+
e_env_force = Config(env_name_force="ENV_TRUE", default=False)
2427

2528

2629
class nested:

torch/utils/_config_module.py

+48-2
Original file line numberDiff line numberDiff line change
@@ -28,31 +28,60 @@ class Config:
2828
This configs must be installed with install_config_module to be used
2929
3030
Precedence Order:
31+
env_name_force: If set, this environment variable overrides everything
3132
user_override: If a user sets a value (i.e. foo.bar=True), that
32-
has the highest precendance and is always respected
33+
has precedence over everything after this.
34+
env_name_default: If set, this environment variable will override everything
35+
after this.
3336
justknob: If this pytorch installation supports justknobs, that will
3437
override defaults, but will not override the user_override precendence.
3538
default: This value is the lowest precendance, and will be used if nothing is
3639
set.
3740
41+
Environment Variables:
42+
These are interpreted to be either "0" or "1" to represent true and false.
43+
3844
Arguments:
3945
justknob: the name of the feature / JK. In OSS this is unused.
4046
default: is the value to default this knob to in OSS.
47+
env_name_force: The environment variable to read that is a FORCE
48+
environment variable. I.e. it overrides everything
49+
env_name_default: The environment variable to read that changes the
50+
default behaviour. I.e. user overrides take preference.
4151
"""
4252

4353
default: Any = True
4454
justknob: Optional[str] = None
55+
env_name_default: Optional[str] = None
56+
env_name_force: Optional[str] = None
4557

46-
def __init__(self, default: Any = True, justknob: Optional[str] = None):
58+
def __init__(
59+
self,
60+
default: Any = True,
61+
justknob: Optional[str] = None,
62+
env_name_default: Optional[str] = None,
63+
env_name_force: Optional[str] = None,
64+
):
4765
# python 3.9 does not support kw_only on the dataclass :(.
4866
self.default = default
4967
self.justknob = justknob
68+
self.env_name_default = env_name_default
69+
self.env_name_force = env_name_force
5070

5171

5272
# Types saved/loaded in configs
5373
CONFIG_TYPES = (int, float, bool, type(None), str, list, set, tuple, dict)
5474

5575

76+
def _read_env_variable(name: str) -> Optional[bool]:
77+
value = os.environ.get(name)
78+
if value == "1":
79+
return True
80+
if value == "0":
81+
return False
82+
return None
83+
84+
5685
def install_config_module(module: ModuleType) -> None:
5786
"""
5887
Converts a module-level config into a `ConfigModule()`.
@@ -87,6 +116,7 @@ def visit(
87116
delattr(module, key)
88117
elif isinstance(value, Config):
89118
config[name] = _ConfigEntry(value)
119+
90120
if dest is module:
91121
delattr(module, key)
92122
elif isinstance(value, type):
@@ -167,10 +197,19 @@ class _ConfigEntry:
167197
user_override: Any = _UNSET_SENTINEL
168198
# The justknob to check for this config
169199
justknob: Optional[str] = None
200+
# environment variables are read at install time
201+
env_value_force: Any = _UNSET_SENTINEL
202+
env_value_default: Any = _UNSET_SENTINEL
170203

171204
def __init__(self, config: Config):
172205
self.default = config.default
173206
self.justknob = config.justknob
207+
if config.env_name_default is not None:
208+
if (env_value := _read_env_variable(config.env_name_default)) is not None:
209+
self.env_value_default = env_value
210+
if config.env_name_force is not None:
211+
if (env_value := _read_env_variable(config.env_name_force)) is not None:
212+
self.env_value_force = env_value
174213

175214

176215
class ConfigModule(ModuleType):
@@ -202,9 +241,16 @@ def __setattr__(self, name: str, value: object) -> None:
202241
def __getattr__(self, name: str) -> Any:
203242
try:
204243
config = self._config[name]
244+
245+
if config.env_value_force is not _UNSET_SENTINEL:
246+
return config.env_value_force
247+
205248
if config.user_override is not _UNSET_SENTINEL:
206249
return config.user_override
207250

251+
if config.env_value_default is not _UNSET_SENTINEL:
252+
return config.env_value_default
253+
208254
if config.justknob is not None:
209255
# JK only supports bools and ints
210256
return justknobs_check(name=config.justknob, default=config.default)

0 commit comments

Comments
 (0)