Skip to content

Commit d9a461b

Browse files
authored
extend patch support for truss dir patch applier (#553)
* extend patch support for truss dir patch applier * bump version * add missing continue * correct version * always overwrite config when applicable * add tests, remove redundant update
1 parent 0c4c2a3 commit d9a461b

File tree

3 files changed

+35
-10
lines changed

3 files changed

+35
-10
lines changed

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "truss"
3-
version = "0.5.8"
3+
version = "0.5.9rc1"
44
description = "A seamless bridge from model development to model delivery"
55
license = "MIT"
66
readme = "README.md"

truss/patch/truss_dir_patch_applier.py

+14-7
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import logging
2-
from dataclasses import replace
32
from pathlib import Path
43
from typing import List
54

@@ -16,6 +15,9 @@
1615
)
1716
from truss.templates.control.control.helpers.types import (
1817
Action,
18+
ConfigPatch,
19+
EnvVarPatch,
20+
ExternalDataPatch,
1921
ModelCodePatch,
2022
Patch,
2123
PythonRequirementPatch,
@@ -40,6 +42,7 @@ def __call__(self, patches: List[Patch]):
4042
# Aggregate config patches and apply at end
4143
reqs = reqs_by_name(self._truss_config.requirements)
4244
pkgs = system_packages_set(self._truss_config.system_packages)
45+
new_config = self._truss_config
4346
for patch in patches:
4447
self._logger.debug(f"Applying patch {patch.to_dict()}")
4548
action = patch.body.action
@@ -67,11 +70,15 @@ def __call__(self, patches: List[Patch]):
6770
if action == Action.ADD or Action.UPDATE:
6871
pkgs.add(pkg)
6972
continue
73+
# Each of EnvVarPatch and ExternalDataPatch can be expressed through an overwrite of the config,
74+
# handled below
75+
if isinstance(patch.body, EnvVarPatch):
76+
continue
77+
if isinstance(patch.body, ExternalDataPatch):
78+
continue
79+
if isinstance(patch.body, ConfigPatch):
80+
new_config = TrussConfig.from_dict(patch.body.config)
81+
continue
7082
raise UnsupportedPatch(f"Unknown patch type {patch.type}")
7183

72-
self._truss_config = replace(
73-
self._truss_config,
74-
requirements=list(reqs.values()),
75-
system_packages=list(pkgs),
76-
)
77-
self._truss_config.write_to_yaml_file(self._truss_config_path)
84+
new_config.write_to_yaml_file(self._truss_config_path)

truss/tests/patch/test_truss_dir_patch_applier.py

+20-2
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import logging
22
from pathlib import Path
33

4+
import yaml
45
from truss.patch.truss_dir_patch_applier import TrussDirPatchApplier
56
from truss.templates.control.control.helpers.types import (
67
Action,
8+
ConfigPatch,
79
ModelCodePatch,
810
Patch,
911
PatchType,
@@ -33,6 +35,8 @@ def test_model_code_patch(custom_model_truss_dir: Path):
3335
def test_python_requirement_patch(custom_model_truss_dir: Path):
3436
req = "git+https://github.com/huggingface/transformers.git"
3537
applier = TrussDirPatchApplier(custom_model_truss_dir, TEST_LOGGER)
38+
config = yaml.safe_load((custom_model_truss_dir / "config.yaml").read_text())
39+
config["requirements"] = [req]
3640
applier(
3741
[
3842
Patch(
@@ -41,7 +45,13 @@ def test_python_requirement_patch(custom_model_truss_dir: Path):
4145
action=Action.ADD,
4246
requirement=req,
4347
),
44-
)
48+
),
49+
Patch(
50+
type=PatchType.CONFIG,
51+
body=ConfigPatch(
52+
action=Action.UPDATE, config=config, path="config.yaml"
53+
),
54+
),
4555
]
4656
)
4757
assert TrussConfig.from_yaml(
@@ -51,15 +61,23 @@ def test_python_requirement_patch(custom_model_truss_dir: Path):
5161

5262
def test_system_requirement_patch(custom_model_truss_dir: Path):
5363
applier = TrussDirPatchApplier(custom_model_truss_dir, TEST_LOGGER)
64+
config = yaml.safe_load((custom_model_truss_dir / "config.yaml").read_text())
65+
config["system_packages"] = ["curl"]
5466
applier(
5567
[
68+
Patch(
69+
type=PatchType.CONFIG,
70+
body=ConfigPatch(
71+
action=Action.UPDATE, config=config, path="config.yaml"
72+
),
73+
),
5674
Patch(
5775
type=PatchType.SYSTEM_PACKAGE,
5876
body=SystemPackagePatch(
5977
action=Action.ADD,
6078
package="curl",
6179
),
62-
)
80+
),
6381
]
6482
)
6583
assert TrussConfig.from_yaml(

0 commit comments

Comments
 (0)