Skip to content

Commit 1b0eb94

Browse files
authoredMar 10, 2023
Fix integration tests (#245)
* publish RC version * fix up integration tests * undo version bump
1 parent c5b89a6 commit 1b0eb94

File tree

2 files changed

+9
-13
lines changed

2 files changed

+9
-13
lines changed
 

‎truss/tests/conftest.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -226,9 +226,7 @@ def load(*args, **kwargs):
226226
227227
def preprocess(self, model_input):
228228
# Adds 1 to all
229-
return {
230-
'inputs': [value + 1 for value in model_input],
231-
}
229+
return [value + 1 for value in model_input]
232230
233231
def predict(self, model_input):
234232
return {

‎truss/tests/test_truss_handle.py

+8-10
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,6 @@ def test_readme_generation_str_example(
8181
):
8282
th = TrussHandle(custom_model_truss_dir_with_pre_and_post_str_example)
8383
readme_contents = th.generate_readme()
84-
with open("tmp.md", "w") as f:
85-
f.write(readme_contents)
8684
readme_contents = readme_contents.replace("\n", "")
8785
correct_readme_contents = _read_readme("readme_str_example.md")
8886
assert readme_contents == correct_readme_contents
@@ -578,8 +576,8 @@ def test_control_truss_apply_patch(custom_model_control):
578576
running_hash = th.truss_hash_on_serving_container()
579577
new_model_code = """
580578
class Model:
581-
def predict(self, request):
582-
return [2 for i in request['inputs']]
579+
def predict(self, model_input):
580+
return [2 for i in model_input]
583581
"""
584582
patch_request = {
585583
"hash": "dummy",
@@ -620,8 +618,8 @@ def test_regular_truss_local_update_flow(custom_model_truss_dir):
620618
model_code_file.write(
621619
"""
622620
class Model:
623-
def predict(self, request):
624-
return [2 for i in request['inputs']]
621+
def predict(self, model_input):
622+
return [2 for i in model_input]
625623
"""
626624
)
627625
result = th.docker_predict([1], tag=tag)
@@ -681,8 +679,8 @@ def test_control_truss_local_update_flow(binary, python_version, custom_model_co
681679
def predict_with_updated_model_code():
682680
new_model_code = """
683681
class Model:
684-
def predict(self, request):
685-
return [2 for i in request['inputs']]
682+
def predict(self, model_input):
683+
return [2 for i in model_input]
686684
"""
687685
model_code_file_path = custom_model_control / "model" / "model.py"
688686
with model_code_file_path.open("w") as model_code_file:
@@ -803,8 +801,8 @@ def malformed
803801
# Should be able to fix code after
804802
good_model_code = """
805803
class Model:
806-
def predict(self, request):
807-
return [2 for i in request['inputs']]
804+
def predict(self, model_input):
805+
return [2 for i in model_input]
808806
"""
809807
with model_code_file_path.open("w") as model_code_file:
810808
model_code_file.write(good_model_code)

0 commit comments

Comments
 (0)
Please sign in to comment.