Skip to content

Commit 1a55494

Browse files
authored
add_python_req modified to handle duplicates, change the version if exists (#231)
1 parent b97490b commit 1a55494

File tree

2 files changed

+29
-6
lines changed

2 files changed

+29
-6
lines changed

truss/tests/test_truss_handle.py

+15
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,21 @@ def test_update_requirements(custom_model_truss_dir_with_pre_and_post):
427427
sc_requirements = th.spec.requirements
428428
assert sc_requirements == requirements
429429

430+
def test_add_python_requirements_new(custom_model_truss_dir_with_pre_and_post):
431+
th = TrussHandle(custom_model_truss_dir_with_pre_and_post)
432+
python_package_to_add = "tensorflow==2.3.1"
433+
th.add_python_requirement(python_package_to_add)
434+
sc_requirements = th.spec.requirements
435+
assert python_package_to_add in sc_requirements
436+
437+
def test_add_python_requirements_version_change(custom_model_truss_dir_with_pre_and_post):
438+
th = TrussHandle(custom_model_truss_dir_with_pre_and_post)
439+
th.add_python_requirement("tensorflow==2.3.1")
440+
441+
python_package_to_change_version = "tensorflow==2.3.5"
442+
th.add_python_requirement(python_package_to_change_version)
443+
sc_requirements = th.spec.requirements
444+
assert python_package_to_change_version.split('==')[1]==sc_requirements[0].split('==')[1]
430445

431446
def test_update_requirements_from_file(
432447
custom_model_truss_dir_with_pre_and_post, tmp_path

truss/truss_handle.py

+14-6
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
1111
from urllib.error import HTTPError
1212

13+
import pkg_resources
1314
import requests
1415
import yaml
1516
from requests import exceptions
@@ -79,6 +80,10 @@
7980
logger.addHandler(logging.StreamHandler(sys.stdout))
8081

8182

83+
def _python_req_name(python_requirement: str) -> str:
84+
return pkg_resources.Requirement.parse(python_requirement).name
85+
86+
8287
class TrussHandle:
8388
def __init__(self, truss_dir: Path, validate: bool = True) -> None:
8489
self._truss_dir = truss_dir
@@ -106,7 +111,6 @@ def spec(self) -> TrussSpec:
106111
def _wait_for_predict(
107112
model_base_url: str, request: Dict, binary: bool = False
108113
) -> Response:
109-
110114
url = f"{model_base_url}/v1/models/model:predict"
111115

112116
if binary:
@@ -397,11 +401,15 @@ def training_docker_build_setup(self, build_dir: Optional[Path] = None):
397401

398402
def add_python_requirement(self, python_requirement: str):
399403
"""Add a python requirement to truss model's config."""
400-
self._update_config(
401-
lambda conf: replace(
402-
conf, requirements=[*conf.requirements, python_requirement]
403-
)
404-
)
404+
405+
# Parse the added python requirements
406+
input_python_req_name = _python_req_name(python_requirement)
407+
new_reqs = [
408+
req for req in self._spec.config.requirements
409+
if _python_req_name(req) != input_python_req_name
410+
]
411+
new_reqs.append(python_requirement)
412+
self._update_config(lambda conf: replace(conf, requirements=new_reqs))
405413

406414
def remove_python_requirement(self, python_requirement: str):
407415
"""Remove a python requirement to truss model's config.

0 commit comments

Comments
 (0)