Skip to content

Commit 56038cb

Browse files
authored
Reverted #231 (#252)
* Reverted basetenlabs/baseten#5830. * Fix issue from merge. * Bump truss version.
1 parent 730fb33 commit 56038cb

File tree

3 files changed

+6
-39
lines changed

3 files changed

+6
-39
lines changed

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "truss"
3-
version = "0.4.0"
3+
version = "0.4.1"
44
description = "A seamless bridge from model development to model delivery"
55
license = "MIT"
66
readme = "README.md"

truss/tests/test_truss_handle.py

-23
Original file line numberDiff line numberDiff line change
@@ -416,29 +416,6 @@ def test_update_requirements(custom_model_truss_dir_with_pre_and_post):
416416
assert sc_requirements == requirements
417417

418418

419-
def test_add_python_requirements_new(custom_model_truss_dir_with_pre_and_post):
420-
th = TrussHandle(custom_model_truss_dir_with_pre_and_post)
421-
python_package_to_add = "tensorflow==2.3.1"
422-
th.add_python_requirement(python_package_to_add)
423-
sc_requirements = th.spec.requirements
424-
assert python_package_to_add in sc_requirements
425-
426-
427-
def test_add_python_requirements_version_change(
428-
custom_model_truss_dir_with_pre_and_post,
429-
):
430-
th = TrussHandle(custom_model_truss_dir_with_pre_and_post)
431-
th.add_python_requirement("tensorflow==2.3.1")
432-
433-
python_package_to_change_version = "tensorflow==2.3.5"
434-
th.add_python_requirement(python_package_to_change_version)
435-
sc_requirements = th.spec.requirements
436-
assert (
437-
python_package_to_change_version.split("==")[1]
438-
== sc_requirements[0].split("==")[1]
439-
)
440-
441-
442419
def test_update_requirements_from_file(
443420
custom_model_truss_dir_with_pre_and_post, tmp_path
444421
):

truss/truss_handle.py

+5-15
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
1111
from urllib.error import HTTPError
1212

13-
import pkg_resources
1413
import requests
1514
import yaml
1615
from requests import exceptions
@@ -398,15 +397,11 @@ def training_docker_build_setup(self, build_dir: Optional[Path] = None):
398397
def add_python_requirement(self, python_requirement: str):
399398
"""Add a python requirement to truss model's config."""
400399

401-
# Parse the added python requirements
402-
input_python_req_name = _python_req_name(python_requirement)
403-
new_reqs = [
404-
req
405-
for req in self._spec.config.requirements
406-
if _python_req_name(req) != input_python_req_name
407-
]
408-
new_reqs.append(python_requirement)
409-
self._update_config(lambda conf: replace(conf, requirements=new_reqs))
400+
self._update_config(
401+
lambda conf: replace(
402+
conf, requirements=[*conf.requirements, python_requirement]
403+
)
404+
)
410405

411406
def remove_python_requirement(self, python_requirement: str):
412407
"""Remove a python requirement to truss model's config.
@@ -1111,8 +1106,3 @@ def _docker_image_from_labels(labels: Dict):
11111106
images = get_images(labels)
11121107
if images and isinstance(images, list):
11131108
return images[0]
1114-
1115-
1116-
def _python_req_name(python_requirement: str) -> str:
1117-
req = pkg_resources.Requirement.parse(python_requirement)
1118-
return req.name # type: ignore

0 commit comments

Comments
 (0)