Skip to content

Commit 27dced4

Browse files
Multinode Cleanup (#1360)
* nodecount cleanup * adds node count to dict marshalling * updates tests * updates test * remove unnecessary things * pr review
1 parent f7be3e0 commit 27dced4

File tree

5 files changed

+43
-18
lines changed

5 files changed

+43
-18
lines changed

truss/base/truss_config.py

+17-14
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from truss.base.validation import (
2020
validate_cpu_spec,
2121
validate_memory_spec,
22+
validate_node_count,
2223
validate_python_executable_path,
2324
validate_secret_name,
2425
validate_secret_to_path_mapping,
@@ -43,7 +44,6 @@
4344
DEFAULT_CPU = "1"
4445
DEFAULT_MEMORY = "2Gi"
4546
DEFAULT_USE_GPU = False
46-
DEFAULT_NODE_COUNT = 1
4747

4848
DEFAULT_BLOB_BACKEND = HTTP_PUBLIC_BLOB_BACKEND
4949

@@ -261,7 +261,7 @@ class Resources:
261261
memory: str = DEFAULT_MEMORY
262262
use_gpu: bool = DEFAULT_USE_GPU
263263
accelerator: AcceleratorSpec = field(default_factory=AcceleratorSpec)
264-
node_count: int = DEFAULT_NODE_COUNT
264+
node_count: Optional[int] = None
265265

266266
@staticmethod
267267
def from_dict(d):
@@ -273,24 +273,27 @@ def from_dict(d):
273273
use_gpu = d.get("use_gpu", DEFAULT_USE_GPU)
274274
if accelerator.accelerator is not None:
275275
use_gpu = True
276-
# TODO[rcano]: add validation for node count
277-
node_count = d.get("node_count", DEFAULT_NODE_COUNT)
278-
279-
return Resources(
280-
cpu=cpu,
281-
memory=memory,
282-
use_gpu=use_gpu,
283-
accelerator=accelerator,
284-
node_count=node_count,
285-
)
276+
277+
r = Resources(cpu=cpu, memory=memory, use_gpu=use_gpu, accelerator=accelerator)
278+
279+
# only add node_count if not None. This helps keep
280+
# config generated by truss init concise.
281+
node_count = d.get("node_count")
282+
validate_node_count(node_count)
283+
r.node_count = node_count
284+
285+
return r
286286

287287
def to_dict(self):
288-
return {
288+
d = {
289289
"cpu": self.cpu,
290290
"memory": self.memory,
291291
"use_gpu": self.use_gpu,
292292
"accelerator": self.accelerator.to_str(),
293293
}
294+
if self.node_count is not None:
295+
d["node_count"] = self.node_count
296+
return d
294297

295298

296299
@dataclass
@@ -775,7 +778,7 @@ def _handle_env_vars(env_vars: Dict[str, Any]) -> Dict[str, str]:
775778

776779

777780
DATACLASS_TO_REQ_KEYS_MAP = {
778-
Resources: {"accelerator", "cpu", "memory", "use_gpu", "node_count"},
781+
Resources: {"accelerator", "cpu", "memory", "use_gpu"},
779782
Runtime: {"predict_concurrency"},
780783
Build: {"model_server"},
781784
TrussConfig: {

truss/base/validation.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import math
22
import re
33
from pathlib import PurePosixPath
4-
from typing import Dict, Pattern
4+
from typing import Any, Dict, Pattern
55

66
from truss.base.constants import REGISTRY_BUILD_SECRET_PREFIX
77
from truss.base.errors import ValidationError
@@ -122,3 +122,17 @@ def validate_python_executable_path(path: str) -> None:
122122
raise ValidationError(
123123
f"Invalid relative python executable path {path}. Provide an absolute path"
124124
)
125+
126+
127+
def validate_node_count(node_count: Any) -> None:
128+
fieldpath = "resources.node_count"
129+
if node_count is None:
130+
return None
131+
if not isinstance(node_count, int):
132+
raise ValidationError(
133+
f"{fieldpath} must be a postiive integer. Got {node_count} of type '{type(node_count)}'"
134+
)
135+
if node_count < 1:
136+
raise ValidationError(
137+
f"{fieldpath} must be a positive integer. Got {node_count}."
138+
)

truss/tests/conftest.py

-1
Original file line numberDiff line numberDiff line change
@@ -731,7 +731,6 @@ def default_config() -> Dict[str, Any]:
731731
"cpu": "1",
732732
"memory": "2Gi",
733733
"use_gpu": False,
734-
"node_count": 1,
735734
},
736735
"secrets": {},
737736
"system_packages": [],

truss/tests/test_config.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,17 @@
7878
"accelerator": "A10G:4",
7979
},
8080
),
81+
(
82+
{"node_count": 2},
83+
Resources(node_count=2),
84+
{
85+
"cpu": DEFAULT_CPU,
86+
"memory": DEFAULT_MEMORY,
87+
"use_gpu": False,
88+
"accelerator": None,
89+
"node_count": 2,
90+
},
91+
),
8192
],
8293
)
8394
def test_parse_resources(input_dict, expect_resources, output_dict):
@@ -170,7 +181,6 @@ def test_default_config_not_crowded_end_to_end():
170181
accelerator: null
171182
cpu: '1'
172183
memory: 2Gi
173-
node_count: 1
174184
use_gpu: false
175185
secrets: {}
176186
system_packages: []

truss/tests/test_truss_handle.py

-1
Original file line numberDiff line numberDiff line change
@@ -806,7 +806,6 @@ def generate_default_config():
806806
"cpu": "1",
807807
"memory": "2Gi",
808808
"use_gpu": False,
809-
"node_count": 1,
810809
},
811810
"secrets": {},
812811
"system_packages": [],

0 commit comments

Comments
 (0)