Skip to content

Commit d08b512

Browse files
authored
Add is_healthy to truss (#1283)
* first poc * change logs to be time elapsed * reset _first_health_check_failure on is_ready success * is_ready fail fast on load failure * move logging to after load finished, update log text * remove comment * cr * Add `is_ready` to chains (#1289) * first pass at is_ready in chains * revert example chainlet * bump ctx builder * address comments + add tests * add test chain * fix test * cr fixes * fix test * cr * fix formatting * is_ready -> is_healthy * more refactoring * fix assert * fix couple more tests * fix for extra args * marius cr * fix linting * new loaded endpoint + customize retries * fix test_e2e * fix tests * filter out loaded logs * move reroute logic into helper * docstrings + wait 10s max * bump ctx builder
1 parent 2a2d054 commit d08b512

File tree

16 files changed

+588
-23
lines changed

16 files changed

+588
-23
lines changed

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "truss"
3-
version = "0.9.59rc017"
3+
version = "0.9.59rc018"
44
description = "A seamless bridge from model development to model delivery"
55
license = "MIT"
66
readme = "README.md"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import truss_chains as chains
2+
3+
4+
class CustomHealthChecks(chains.ChainletBase):
    """Chainlet whose health status can be toggled through `run_remote`.

    `is_healthy` reports the value of an internal flag; every call to
    `run_remote` overwrites that flag based on its `fail` argument.
    """

    def __init__(self):
        # A fresh instance reports healthy until `run_remote(fail=True)` is called.
        self._should_succeed_health_checks = True

    def is_healthy(self) -> bool:
        """Return the current value of the health flag."""
        return self._should_succeed_health_checks

    async def run_remote(self, fail: bool) -> str:
        """Set the health flag to the negation of `fail` and describe the new state."""
        self._should_succeed_health_checks = not fail
        outcome = "succeed" if self._should_succeed_health_checks else "fail"
        return f"health checks will {outcome}"

truss-chains/tests/test_e2e.py

+40-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55

66
import pytest
77
import requests
8-
from truss.tests.test_testing_utilities_for_other_tests import ensure_kill_all
8+
from truss.tests.test_testing_utilities_for_other_tests import (
9+
ensure_kill_all,
10+
get_container_logs_from_prefix,
11+
)
912
from truss.truss_handle.build import load
1013

1114
from truss_chains import definitions, framework, public_api, utils
@@ -275,3 +278,39 @@ def test_traditional_truss():
275278
)
276279
assert response.status_code == 200
277280
assert response.json() == 5
281+
282+
283+
@pytest.mark.integration
def test_custom_health_checks_chain():
    """Deploy the `CustomHealthChecks` chain to local docker and probe its health URL.

    The chainlet's `run_remote` toggles whether its `is_healthy` succeeds. The
    test first verifies the passing state (HTTP 200, no failure log), then
    switches to the failing state and verifies that each probe returns HTTP 503
    and adds exactly one "Health check failed." entry to the container logs.
    """
    with ensure_kill_all():
        chain_root = TEST_ROOT / "custom_health_checks" / "custom_health_checks.py"
        with framework.import_target(chain_root, "CustomHealthChecks") as entrypoint:
            service = deployment_client.push(
                entrypoint,
                options=definitions.PushOptionsLocalDocker(
                    chain_name="integration-test-custom-health-checks",
                    only_generate_trusses=False,
                    use_local_chains_src=True,
                ),
            )

            assert service is not None
            # The probed URL is the run-remote URL with everything from
            # ":predict" onwards stripped off.
            health_check_url = service.run_remote_url.split(":predict")[0]

            response = service.run_remote({"fail": False})
            assert response.status_code == 200
            response = requests.get(health_check_url)
            # This comparison was previously a bare expression (no `assert`),
            # so the passing status code was never actually verified.
            assert response.status_code == 200
            container_logs = get_container_logs_from_prefix(entrypoint.name)
            assert "Health check failed." not in container_logs

            # Start failing health checks
            response = service.run_remote({"fail": True})
            response = requests.get(health_check_url)
            assert response.status_code == 503
            container_logs = get_container_logs_from_prefix(entrypoint.name)
            assert container_logs.count("Health check failed.") == 1
            # A second probe must log one additional failure.
            response = requests.get(health_check_url)
            assert response.status_code == 503
            container_logs = get_container_logs_from_prefix(entrypoint.name)
            assert container_logs.count("Health check failed.") == 2

truss-chains/tests/test_framework.py

+91-2
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,7 @@ def run_remote(self, arg: "SomeModel") -> None:
307307
def test_raises_endpoint_no_method():
308308
match = (
309309
rf"{TEST_FILE}:\d+ \(StaticMethod\.run_remote\) \[kind: TYPE_ERROR\].*"
310-
r"Endpoint must be a method"
310+
r"`run_remote` must be a method"
311311
)
312312

313313
with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors():
@@ -321,7 +321,7 @@ def run_remote() -> None:
321321
def test_raises_endpoint_no_method_arg():
322322
match = (
323323
rf"{TEST_FILE}:\d+ \(StaticMethod\.run_remote\) \[kind: TYPE_ERROR\].*"
324-
r"Endpoint must be a method"
324+
r"`run_remote` must be a method"
325325
)
326326

327327
with pytest.raises(definitions.ChainsUsageError, match=match), _raise_errors():
@@ -579,3 +579,92 @@ def test_raises_iterator_no_arg():
579579
class IteratorNoArg(chains.ChainletBase):
580580
async def run_remote(self) -> AsyncIterator:
581581
yield "123"
582+
583+
584+
def test_raises_is_healthy_not_a_method():
    """A non-function `is_healthy` class attribute is reported as a TYPE_ERROR."""
    location_part = rf"{TEST_FILE}:\d+ \(IsHealthyNotMethod\) \[kind: TYPE_ERROR\].*"
    message_part = r" `is_healthy` must be a method."
    expected_pattern = location_part + message_part

    with pytest.raises(definitions.ChainsUsageError, match=expected_pattern):
        with _raise_errors():

            class IsHealthyNotMethod(chains.ChainletBase):
                is_healthy: int = 3

                async def run_remote(self) -> str:
                    return ""
594+
595+
596+
def test_raises_is_healthy_no_arg():
    """An `is_healthy` defined without any parameters is reported as a TYPE_ERROR."""
    location_part = (
        rf"{TEST_FILE}:\d+ \(IsHealthyNoArg\.is_healthy\) \[kind: TYPE_ERROR\].*"
    )
    message_part = (
        r"`is_healthy` must be a method, i.e. with `self` as first argument. "
        r"Got function with no arguments."
    )
    expected_pattern = location_part + message_part

    with pytest.raises(definitions.ChainsUsageError, match=expected_pattern):
        with _raise_errors():

            class IsHealthyNoArg(chains.ChainletBase):
                async def is_healthy() -> bool:
                    return True

                async def run_remote(self) -> str:
                    return ""
610+
611+
612+
def test_raises_is_healthy_first_arg_not_self():
    """An `is_healthy` whose first parameter is not `self` is reported as a TYPE_ERROR."""
    location_part = (
        rf"{TEST_FILE}:\d+ \(IsHealthyNoSelfArg\.is_healthy\) \[kind: TYPE_ERROR\].*"
    )
    message_part = (
        r"`is_healthy` must be a method, i.e. with `self` as first argument. "
        r"Got `hi` as first argument."
    )
    expected_pattern = location_part + message_part

    with pytest.raises(definitions.ChainsUsageError, match=expected_pattern):
        with _raise_errors():

            class IsHealthyNoSelfArg(chains.ChainletBase):
                def is_healthy(hi) -> bool:
                    return True

                async def run_remote(self) -> str:
                    return ""
626+
627+
628+
def test_raises_is_healthy_multiple_args():
    """An `is_healthy` taking parameters besides `self` is reported as a TYPE_ERROR."""
    location_part = (
        rf"{TEST_FILE}:\d+ \(IsHealthyManyArgs\.is_healthy\) \[kind: TYPE_ERROR\].*"
    )
    message_part = r" `is_healthy` must have only one argument: `self`."
    expected_pattern = location_part + message_part

    with pytest.raises(definitions.ChainsUsageError, match=expected_pattern):
        with _raise_errors():

            class IsHealthyManyArgs(chains.ChainletBase):
                def is_healthy(self, hi) -> bool:
                    return True

                async def run_remote(self) -> str:
                    return ""
639+
640+
641+
def test_raises_is_healthy_not_type_annotated():
    """An `is_healthy` without a return annotation is reported as an IO_TYPE_ERROR."""
    location_part = (
        rf"{TEST_FILE}:\d+ \(IsHealthyNotTyped\.is_healthy\) \[kind: IO_TYPE_ERROR\].*"
    )
    message_part = (
        r"Return value of health check must be type annotated. "
        r"Got:\n\tis_healthy\(self\) -> !MISSING!"
    )
    expected_pattern = location_part + message_part

    with pytest.raises(definitions.ChainsUsageError, match=expected_pattern):
        with _raise_errors():

            class IsHealthyNotTyped(chains.ChainletBase):
                def is_healthy(self):
                    return True

                async def run_remote(self) -> str:
                    return ""
655+
656+
657+
def test_raises_is_healthy_not_boolean_typed():
    """An `is_healthy` annotated with a non-`bool` return is reported as an IO_TYPE_ERROR."""
    location_part = (
        rf"{TEST_FILE}:\d+ \(IsHealthyNotBoolTyped\.is_healthy\) \[kind: IO_TYPE_ERROR\].*"
    )
    message_part = (
        r"Return value of health check must be a boolean. "
        r"Got:\n\tis_healthy\(self\) -> str -> <class 'str'>"
    )
    expected_pattern = location_part + message_part

    with pytest.raises(definitions.ChainsUsageError, match=expected_pattern):
        with _raise_errors():

            class IsHealthyNotBoolTyped(chains.ChainletBase):
                def is_healthy(self) -> str:  # type: ignore[misc]
                    return "not ready"

                async def run_remote(self) -> str:
                    return ""

truss-chains/truss_chains/definitions.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -34,9 +34,10 @@
3434
GENERATED_CODE_DIR = ".chains_generated"
3535
DYNAMIC_CHAINLET_CONFIG_KEY = "dynamic_chainlet_config"
3636
OTEL_TRACE_PARENT_HEADER_KEY = "traceparent"
37-
# Below arg names must correspond to `definitions.ABCChainlet`.
3837
RUN_REMOTE_METHOD_NAME = "run_remote" # Chainlet method name exposed as endpoint.
3938
MODEL_ENDPOINT_METHOD_NAME = "predict" # Model method name exposed as endpoint.
39+
HEALTH_CHECK_METHOD_NAME = "is_healthy"
40+
# Below arg names must correspond to `definitions.ABCChainlet`.
4041
CONTEXT_ARG_NAME = "context" # Referring to Chainlets `__init__` signature.
4142
SELF_ARG_NAME = "self"
4243
REMOTE_CONFIG_NAME = "remote_config"
@@ -625,12 +626,18 @@ def display_name(self) -> str:
625626
return self.chainlet_cls.display_name
626627

627628

629+
class HealthCheckAPIDescriptor(SafeModelNonSerializable):
    """Describes the health check method of a Chainlet."""

    # Name of the health check method; defaults to `HEALTH_CHECK_METHOD_NAME`.
    name: str = HEALTH_CHECK_METHOD_NAME
    # Whether the health check method is a coroutine function (`async def`).
    is_async: bool
632+
633+
628634
class ChainletAPIDescriptor(SafeModelNonSerializable):
629635
chainlet_cls: Type[ABCChainlet]
630636
src_path: str
631637
has_context: bool
632638
dependencies: Mapping[str, DependencyDescriptor]
633639
endpoint: EndpointAPIDescriptor
640+
health_check: Optional[HealthCheckAPIDescriptor]
634641

635642
def __hash__(self) -> int:
636643
return hash(self.chainlet_cls)

truss-chains/truss_chains/deployment/code_gen.py

+18
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,20 @@ def _gen_load_src(chainlet_descriptor: definitions.ChainletAPIDescriptor) -> _So
457457
return _Source(src=src, imports=imports)
458458

459459

460+
def _gen_health_check_src(
    health_check: definitions.HealthCheckAPIDescriptor,
) -> _Source:
    """Build the source text of the truss model's `is_healthy` method.

    The generated method forwards to `self._chainlet.is_healthy()` (awaited
    when the chainlet's health check is async), guarded by a check that the
    `_chainlet` attribute exists on the model.
    """
    if health_check.is_async:
        def_keyword, await_prefix = "async def", "await "
    else:
        def_keyword, await_prefix = "def", ""

    signature_part = f"{def_keyword} is_healthy(self) -> Optional[bool]:\n"
    guard_part = _indent('if hasattr(self, "_chainlet"):')
    return_part = _indent(f"return {await_prefix}self._chainlet.is_healthy()")
    # Pieces are joined without extra separators, exactly as produced by `_indent`.
    return _Source(src=f"{signature_part}{guard_part}{return_part}")
472+
473+
460474
def _gen_predict_src(chainlet_descriptor: definitions.ChainletAPIDescriptor) -> _Source:
461475
"""Generates AST for the `predict` method of the truss model."""
462476
imports: set[str] = {
@@ -538,6 +552,10 @@ def _gen_truss_chainlet_model(
538552
libcst.parse_statement(predict_src.src),
539553
]
540554

555+
if chainlet_descriptor.health_check is not None:
556+
health_check_src = _gen_health_check_src(chainlet_descriptor.health_check)
557+
new_body.extend([libcst.parse_statement(health_check_src.src)])
558+
541559
user_chainlet_ref = _gen_chainlet_import_and_ref(chainlet_descriptor)
542560
imports.update(user_chainlet_ref.imports)
543561

truss-chains/truss_chains/framework.py

+71-8
Original file line numberDiff line numberDiff line change
@@ -339,24 +339,29 @@ def _validate_streaming_output_type(
339339
)
340340

341341

342-
def _validate_endpoint_params(
343-
params: list[inspect.Parameter], location: _ErrorLocation
344-
) -> list[definitions.InputArg]:
342+
def _validate_method_signature(
343+
method_name: str, location: _ErrorLocation, params: list[inspect.Parameter]
344+
) -> None:
345345
if len(params) == 0:
346346
_collect_error(
347-
f"`Endpoint must be a method, i.e. with `{definitions.SELF_ARG_NAME}` as "
347+
f"`{method_name}` must be a method, i.e. with `{definitions.SELF_ARG_NAME}` as "
348348
"first argument. Got function with no arguments.",
349349
_ErrorKind.TYPE_ERROR,
350350
location,
351351
)
352-
return []
353-
if params[0].name != definitions.SELF_ARG_NAME:
352+
elif params[0].name != definitions.SELF_ARG_NAME:
354353
_collect_error(
355-
f"`Endpoint must be a method, i.e. with `{definitions.SELF_ARG_NAME}` as "
354+
f"`{method_name}` must be a method, i.e. with `{definitions.SELF_ARG_NAME}` as "
356355
f"first argument. Got `{params[0].name}` as first argument.",
357356
_ErrorKind.TYPE_ERROR,
358357
location,
359358
)
359+
360+
361+
def _validate_endpoint_params(
362+
params: list[inspect.Parameter], location: _ErrorLocation
363+
) -> list[definitions.InputArg]:
364+
_validate_method_signature(definitions.RUN_REMOTE_METHOD_NAME, location, params)
360365
input_args = []
361366
for param in params[1:]: # Skip self argument.
362367
if param.annotation == inspect.Parameter.empty:
@@ -434,7 +439,7 @@ def _validate_and_describe_endpoint(
434439
```
435440
436441
* The name must be `run_remote` for Chainlets, or `predict` for Models.
437-
* It can be sync or async or def.
442+
* It can be sync or async def.
438443
* The number and names of parameters are arbitrary, both positional and named
439444
parameters are ok.
440445
* All parameters and the return value must have type annotations. See
@@ -742,6 +747,63 @@ def _validate_remote_config(
742747
)
743748

744749

750+
def _validate_health_check(
    cls: Type[definitions.ABCChainlet], location: _ErrorLocation
) -> Optional[definitions.HealthCheckAPIDescriptor]:
    """Validate a Chainlet's optional `is_healthy` method and describe it.

    The expected signature is:
    ```
    [async] def is_healthy(self) -> bool:
    ```
    * The name must be `is_healthy`.
    * It can be sync or async def.
    * Must not define any parameters other than `self`.
    * Must return a boolean.

    Problems are reported through `_collect_error` rather than raised.

    Returns:
        `None` if the class has no `is_healthy` attribute, if that attribute
        is not a function, or if the return annotation is missing; otherwise
        a descriptor recording whether the method is async.
    """
    method_name = definitions.HEALTH_CHECK_METHOD_NAME
    if not hasattr(cls, method_name):
        # Health checks are optional; nothing to validate.
        return None

    candidate = getattr(cls, method_name)
    if not inspect.isfunction(candidate):
        _collect_error(
            f"`{method_name}` must be a method.", _ErrorKind.TYPE_ERROR, location
        )
        return None

    # Point subsequent errors at the method definition instead of the class.
    _, first_line = inspect.getsourcelines(candidate)
    location = location.model_copy(
        update={"line": first_line, "method_name": method_name}
    )

    signature = inspect.signature(candidate)
    params = list(signature.parameters.values())
    _validate_method_signature(method_name, location, params)
    if len(params) > 1:
        _collect_error(
            f"`{method_name}` must have only one argument: `{definitions.SELF_ARG_NAME}`.",
            _ErrorKind.TYPE_ERROR,
            location,
        )

    return_annotation = signature.return_annotation
    if return_annotation == inspect.Parameter.empty:
        _collect_error(
            "Return value of health check must be type annotated. Got:\n"
            f"\t{location.method_name}{signature} -> !MISSING!",
            _ErrorKind.IO_TYPE_ERROR,
            location,
        )
        return None

    if return_annotation is not bool:
        # The error is only collected; a descriptor is still returned below.
        _collect_error(
            "Return value of health check must be a boolean. Got:\n"
            f"\t{location.method_name}{signature} -> {return_annotation}",
            _ErrorKind.IO_TYPE_ERROR,
            location,
        )

    return definitions.HealthCheckAPIDescriptor(
        is_async=inspect.iscoroutinefunction(candidate)
    )
805+
806+
745807
def validate_and_register_cls(cls: Type[definitions.ABCChainlet]) -> None:
746808
"""Note that validation errors will only be collected, not raised, and Chainlets.
747809
with issues, are still added to the registry. Use `raise_validation_errors` to
@@ -759,6 +821,7 @@ def validate_and_register_cls(cls: Type[definitions.ABCChainlet]) -> None:
759821
has_context=init_validator.has_context,
760822
endpoint=_validate_and_describe_endpoint(cls, location),
761823
src_path=src_path,
824+
health_check=_validate_health_check(cls, location),
762825
)
763826
logging.debug(
764827
f"Descriptor for {cls}:\n{pprint.pformat(chainlet_descriptor, indent=4)}\n"

truss-chains/truss_chains/remote_chainlet/model_skeleton.py

+5
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ def __init__(
4747
# side_effect=stub.factory(SideEffectOnlySubclass, self._context),
4848
# )
4949
#
50+
# If chainlet implements is_healthy:
51+
# def is_healthy(self) -> Optional[bool]:
52+
# if hasattr(self, "_chainlet"):
53+
# return self._chainlet.is_healthy()
54+
#
5055
# def predict(
5156
# self, inputs: TextToNumInput, request: starlette.requests.Request
5257
# ) -> TextToNumOutput:

0 commit comments

Comments
 (0)