@@ -339,24 +339,29 @@ def _validate_streaming_output_type(
339
339
)
340
340
341
341
342
- def _validate_endpoint_params (
343
- params : list [inspect .Parameter ], location : _ErrorLocation
344
- ) -> list [ definitions . InputArg ] :
342
+ def _validate_method_signature (
343
+ method_name : str , location : _ErrorLocation , params : list [inspect .Parameter ]
344
+ ) -> None :
345
345
if len (params ) == 0 :
346
346
_collect_error (
347
- f"`Endpoint must be a method, i.e. with `{ definitions .SELF_ARG_NAME } ` as "
347
+ f"`{ method_name } ` must be a method, i.e. with `{ definitions .SELF_ARG_NAME } ` as "
348
348
"first argument. Got function with no arguments." ,
349
349
_ErrorKind .TYPE_ERROR ,
350
350
location ,
351
351
)
352
- return []
353
- if params [0 ].name != definitions .SELF_ARG_NAME :
352
+ elif params [0 ].name != definitions .SELF_ARG_NAME :
354
353
_collect_error (
355
- f"`Endpoint must be a method, i.e. with `{ definitions .SELF_ARG_NAME } ` as "
354
+ f"`{ method_name } ` must be a method, i.e. with `{ definitions .SELF_ARG_NAME } ` as "
356
355
f"first argument. Got `{ params [0 ].name } ` as first argument." ,
357
356
_ErrorKind .TYPE_ERROR ,
358
357
location ,
359
358
)
359
+
360
+
361
+ def _validate_endpoint_params (
362
+ params : list [inspect .Parameter ], location : _ErrorLocation
363
+ ) -> list [definitions .InputArg ]:
364
+ _validate_method_signature (definitions .RUN_REMOTE_METHOD_NAME , location , params )
360
365
input_args = []
361
366
for param in params [1 :]: # Skip self argument.
362
367
if param .annotation == inspect .Parameter .empty :
@@ -434,7 +439,7 @@ def _validate_and_describe_endpoint(
434
439
```
435
440
436
441
* The name must be `run_remote` for Chainlets, or `predict` for Models.
437
- * It can be sync or async or def.
442
+ * It can be sync or async def.
438
443
* The number and names of parameters are arbitrary, both positional and named
439
444
parameters are ok.
440
445
* All parameters and the return value must have type annotations. See
@@ -742,6 +747,63 @@ def _validate_remote_config(
742
747
)
743
748
744
749
750
+ def _validate_health_check (
751
+ cls : Type [definitions .ABCChainlet ], location : _ErrorLocation
752
+ ) -> Optional [definitions .HealthCheckAPIDescriptor ]:
753
+ """The `is_healthy` method of a Chainlet must have the following signature:
754
+ ```
755
+ [async] def is_healthy(self) -> bool:
756
+ ```
757
+ * The name must be `is_healthy`.
758
+ * It can be sync or async def.
759
+ * Must not define any parameters other than `self`.
760
+ * Must return a boolean.
761
+ """
762
+ if not hasattr (cls , definitions .HEALTH_CHECK_METHOD_NAME ):
763
+ return None
764
+
765
+ health_check_method = getattr (cls , definitions .HEALTH_CHECK_METHOD_NAME )
766
+ if not inspect .isfunction (health_check_method ):
767
+ _collect_error (
768
+ f"`{ definitions .HEALTH_CHECK_METHOD_NAME } ` must be a method." ,
769
+ _ErrorKind .TYPE_ERROR ,
770
+ location ,
771
+ )
772
+ return None
773
+
774
+ line = inspect .getsourcelines (health_check_method )[1 ]
775
+ location = location .model_copy (
776
+ update = {"line" : line , "method_name" : definitions .HEALTH_CHECK_METHOD_NAME }
777
+ )
778
+ is_async = inspect .iscoroutinefunction (health_check_method )
779
+ signature = inspect .signature (health_check_method )
780
+ params = list (signature .parameters .values ())
781
+ _validate_method_signature (definitions .HEALTH_CHECK_METHOD_NAME , location , params )
782
+ if len (params ) > 1 :
783
+ _collect_error (
784
+ f"`{ definitions .HEALTH_CHECK_METHOD_NAME } ` must have only one argument: `{ definitions .SELF_ARG_NAME } `." ,
785
+ _ErrorKind .TYPE_ERROR ,
786
+ location ,
787
+ )
788
+ if signature .return_annotation == inspect .Parameter .empty :
789
+ _collect_error (
790
+ "Return value of health check must be type annotated. Got:\n "
791
+ f"\t { location .method_name } { signature } -> !MISSING!" ,
792
+ _ErrorKind .IO_TYPE_ERROR ,
793
+ location ,
794
+ )
795
+ return None
796
+ if signature .return_annotation is not bool :
797
+ _collect_error (
798
+ "Return value of health check must be a boolean. Got:\n "
799
+ f"\t { location .method_name } { signature } -> { signature .return_annotation } " ,
800
+ _ErrorKind .IO_TYPE_ERROR ,
801
+ location ,
802
+ )
803
+
804
+ return definitions .HealthCheckAPIDescriptor (is_async = is_async )
805
+
806
+
745
807
def validate_and_register_cls (cls : Type [definitions .ABCChainlet ]) -> None :
746
808
"""Note that validation errors will only be collected, not raised, and Chainlets.
747
809
with issues, are still added to the registry. Use `raise_validation_errors` to
@@ -759,6 +821,7 @@ def validate_and_register_cls(cls: Type[definitions.ABCChainlet]) -> None:
759
821
has_context = init_validator .has_context ,
760
822
endpoint = _validate_and_describe_endpoint (cls , location ),
761
823
src_path = src_path ,
824
+ health_check = _validate_health_check (cls , location ),
762
825
)
763
826
logging .debug (
764
827
f"Descriptor for { cls } :\n { pprint .pformat (chainlet_descriptor , indent = 4 )} \n "
0 commit comments