Skip to content

Commit 7d73c69

Browse files
authored
Drop legacy Kserve input validation (#219)
* Drop legacy Kserve input validation * Drop validation from model_wrapper.py * Copy model_wrapper.py to test data * Update model_wrapper.py * Update model_wrapper.py * Fix integration test * Drop unused code
1 parent 0ebf701 commit 7d73c69

File tree

8 files changed

+15
-88
lines changed

8 files changed

+15
-88
lines changed

truss/templates/server/common/truss_server.py

-1
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,6 @@ def create_application(self):
159159
),
160160
],
161161
exception_handlers={
162-
errors.InvalidInput: errors.invalid_input_handler,
163162
errors.InferenceError: errors.inference_error_handler,
164163
errors.ModelNotFound: errors.model_not_found_handler,
165164
errors.ModelNotReady: errors.model_not_ready_handler,

truss/templates/server/common/util.py

-12
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,3 @@
1-
from typing import Dict
2-
3-
41
def model_supports_predict_proba(model: object) -> bool:
52
if not hasattr(model, "predict_proba"):
63
return False
@@ -13,12 +10,3 @@ def model_supports_predict_proba(model: object) -> bool:
1310
except AttributeError:
1411
return False
1512
return True
16-
17-
18-
def assign_request_to_inputs_instances_after_validation(body: Dict) -> dict:
19-
# we will treat "instances" and "inputs" the same
20-
if "instances" in body and "inputs" not in body:
21-
body["inputs"] = body["instances"]
22-
elif "inputs" in body and "instances" not in body:
23-
body["instances"] = body["inputs"]
24-
return body

truss/templates/server/model_wrapper.py

-16
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,7 @@
99
from typing import Dict, Optional, Union
1010

1111
import kserve
12-
import numpy as np
1312
from cloudevents.http import CloudEvent
14-
from common.util import assign_request_to_inputs_instances_after_validation
15-
from kserve.errors import InvalidInput
1613
from kserve.grpc.grpc_predict_v2_pb2 import ModelInferRequest, ModelInferResponse
1714
from shared.secrets_resolver import SecretsResolver
1815

@@ -108,19 +105,6 @@ def try_load(self):
108105
if hasattr(self._model, "load"):
109106
self._model.load()
110107

111-
def validate(self, payload):
112-
if (
113-
"instances" in payload
114-
and not isinstance(payload["instances"], (list, np.ndarray))
115-
or "inputs" in payload
116-
and not isinstance(payload["inputs"], (list, np.ndarray))
117-
):
118-
raise InvalidInput(
119-
'Expected "instances" or "inputs" to be a list or NumPy ndarray'
120-
)
121-
122-
return assign_request_to_inputs_instances_after_validation(payload)
123-
124108
def preprocess(
125109
self,
126110
payload: Union[Dict, CloudEvent, ModelInferRequest],

truss/test_data/truss_container_fs/app/common/truss_server.py

-1
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,6 @@ def create_application(self):
159159
),
160160
],
161161
exception_handlers={
162-
errors.InvalidInput: errors.invalid_input_handler,
163162
errors.InferenceError: errors.inference_error_handler,
164163
errors.ModelNotFound: errors.model_not_found_handler,
165164
errors.ModelNotReady: errors.model_not_ready_handler,

truss/test_data/truss_container_fs/app/common/util.py

-9
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,3 @@ def model_supports_predict_proba(model: object) -> bool:
1010
except AttributeError:
1111
return False
1212
return True
13-
14-
15-
def assign_request_to_inputs_instances_after_validation(body: dict) -> dict:
16-
# we will treat "instances" and "inputs" the same
17-
if "instances" in body and "inputs" not in body:
18-
body["inputs"] = body["instances"]
19-
elif "inputs" in body and "instances" not in body:
20-
body["instances"] = body["inputs"]
21-
return body

truss/test_data/truss_container_fs/app/model_wrapper.py

+14-25
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,10 @@
66
from enum import Enum
77
from pathlib import Path
88
from threading import Lock, Thread
9-
from typing import Dict, Union
9+
from typing import Dict, Optional, Union
1010

1111
import kserve
12-
import numpy as np
1312
from cloudevents.http import CloudEvent
14-
from common.util import assign_request_to_inputs_instances_after_validation
15-
from kserve.errors import InvalidInput
1613
from kserve.grpc.grpc_predict_v2_pb2 import ModelInferRequest, ModelInferResponse
1714
from shared.secrets_resolver import SecretsResolver
1815

@@ -26,14 +23,15 @@ class Status(Enum):
2623
READY = 2
2724
FAILED = 3
2825

29-
_config: dict
26+
_config: Dict
3027
_model: object
3128
_load_lock: Lock = Lock()
3229
_predict_lock: Lock = Lock()
3330
_status: Status = Status.NOT_READY
3431
_logger: logging.Logger
32+
ready: bool
3533

36-
def __init__(self, config: dict):
34+
def __init__(self, config: Dict):
3735
super().__init__(MODEL_BASENAME)
3836
self._config = config
3937
self.logger = logging.getLogger(__name__)
@@ -107,41 +105,32 @@ def try_load(self):
107105
if hasattr(self._model, "load"):
108106
self._model.load()
109107

110-
def validate(self, payload):
111-
if (
112-
"instances" in payload
113-
and not isinstance(payload["instances"], (list, np.ndarray))
114-
or "inputs" in payload
115-
and not isinstance(payload["inputs"], (list, np.ndarray))
116-
):
117-
raise InvalidInput(
118-
'Expected "instances" or "inputs" to be a list or NumPy ndarray'
119-
)
120-
121-
return assign_request_to_inputs_instances_after_validation(payload)
122-
123108
def preprocess(
124109
self,
125110
payload: Union[Dict, CloudEvent, ModelInferRequest],
126-
headers: Dict[str, str] = None,
111+
headers: Optional[Dict[str, str]] = None,
127112
) -> Union[Dict, ModelInferRequest]:
128113
if not hasattr(self._model, "preprocess"):
129114
return payload
130-
return self._model.preprocess(payload)
115+
return self._model.preprocess(payload) # type: ignore
131116

132117
def postprocess(
133-
self, response: Union[Dict, ModelInferResponse], headers: Dict[str, str] = None
118+
self,
119+
response: Union[Dict, ModelInferResponse],
120+
headers: Optional[Dict[str, str]] = None,
134121
) -> Dict:
135122
if not hasattr(self._model, "postprocess"):
136123
return response
137-
return self._model.postprocess(response)
124+
return self._model.postprocess(response) # type: ignore
138125

139126
def predict(
140-
self, payload: Union[Dict, ModelInferRequest], headers: Dict[str, str] = None
127+
self,
128+
payload: Union[Dict, ModelInferRequest],
129+
headers: Optional[Dict[str, str]] = None,
141130
) -> Union[Dict, ModelInferResponse]:
142131
try:
143132
self._predict_lock.acquire()
144-
return self._model.predict(payload)
133+
return self._model.predict(payload) # type: ignore
145134
except Exception:
146135
response = {}
147136
logging.exception("Exception while running predict")

truss/tests/templates/core/server/common/test_util.py

-23
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,5 @@
11
from unittest import mock
22

3-
from truss.templates.server.common.util import (
4-
assign_request_to_inputs_instances_after_validation,
5-
)
6-
7-
8-
def test_assign_request_to_inputs_instances_after_validation():
9-
inputs_input = [1, 2, 3, 4]
10-
inputs_dict = {"inputs": inputs_input}
11-
instances_input = [5, 6, 7, 8]
12-
instances_dict = {"instances": instances_input}
13-
14-
processed_inputs = assign_request_to_inputs_instances_after_validation(inputs_dict)
15-
processed_instances = assign_request_to_inputs_instances_after_validation(
16-
instances_dict
17-
)
18-
19-
assert processed_inputs["instances"] == processed_inputs["inputs"] == inputs_input
20-
assert (
21-
processed_instances["instances"]
22-
== processed_instances["inputs"]
23-
== instances_input
24-
)
25-
263

274
def model_supports_predict_proba():
285
mock_not_predict_proba = mock.Mock(name="mock_not_predict_proba")

truss/tests/test_truss_handle.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ def test_docker_predict_secrets(custom_model_truss_dir_for_secrets):
287287
LocalConfigHandler.set_secret("secret_name", "secret_value")
288288
with ensure_kill_all():
289289
try:
290-
result = th.docker_predict({"inputs": ["secret_name"]}, tag=tag)
290+
result = th.docker_predict({"instances": ["secret_name"]}, tag=tag)
291291
assert result["predictions"][0] == "secret_value"
292292
finally:
293293
LocalConfigHandler.remove_secret("secret_name")

0 commit comments

Comments
 (0)