Skip to content

Commit bce4e28

Browse files
authored
Map python versions greater than 3.9 to 3.9 (#128)
1 parent acb8da9 commit bce4e28

5 files changed

+59
-25
lines changed

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "truss"
3-
version = "0.1.6rc2"
3+
version = "0.1.6"
44
description = "A seamless bridge from model development to model delivery"
55
license = "MIT"
66
readme = "README.md"

truss/build.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from truss.environment_inference.requirements_inference import infer_deps
1111
from truss.errors import FrameworkNotSupportedError
1212
from truss.model_frameworks import MODEL_FRAMEWORKS_BY_TYPE, model_framework_from_model
13-
from truss.model_inference import infer_python_version
13+
from truss.model_inference import infer_python_version, map_to_supported_python_version
1414
from truss.truss_config import DEFAULT_EXAMPLES_FILENAME, TrussConfig
1515
from truss.truss_handle import TrussHandle
1616
from truss.types import ModelFrameworkType
@@ -131,7 +131,7 @@ def mk_truss_from_pipeline(
131131
config = TrussConfig(
132132
model_type="custom",
133133
model_framework=ModelFrameworkType.CUSTOM,
134-
python_version=infer_python_version(),
134+
python_version=map_to_supported_python_version(infer_python_version()),
135135
requirements=requirements,
136136
)
137137

@@ -225,7 +225,7 @@ def init(
225225
config = TrussConfig(
226226
model_type="custom",
227227
model_framework=ModelFrameworkType.CUSTOM,
228-
python_version=infer_python_version(),
228+
python_version=map_to_supported_python_version(infer_python_version()),
229229
)
230230

231231
target_directory_path = populate_target_directory(

truss/model_framework.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import yaml
66
from truss.constants import CONFIG_FILE, TEMPLATES_DIR
77
from truss.environment_inference.requirements_inference import infer_deps
8-
from truss.model_inference import infer_python_version
8+
from truss.model_inference import infer_python_version, map_to_supported_python_version
99
from truss.truss_config import DEFAULT_EXAMPLES_FILENAME, TrussConfig
1010
from truss.types import ModelFrameworkType
1111
from truss.utils import copy_file_path, copy_tree_path
@@ -58,7 +58,7 @@ def to_truss(self, model, target_directory: Path) -> str:
5858
else:
5959
target_examples_path.touch()
6060

61-
python_version = infer_python_version()
61+
python_version = map_to_supported_python_version(infer_python_version())
6262

6363
# Create config
6464
config = TrussConfig(

truss/model_inference.py

+36-19
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import inspect
2+
import logging
23
import pathlib
34
import sys
45
from ast import ClassDef, FunctionDef
@@ -25,20 +26,7 @@
2526
"py39",
2627
}
2728

28-
29-
def _get_entries_for_packages(list_of_requirements, desired_requirements):
30-
name_to_req_str = {}
31-
for req_name in desired_requirements:
32-
for req_spec_full_str in list_of_requirements:
33-
if "==" in req_spec_full_str:
34-
req_spec_name, req_version = req_spec_full_str.split("==")
35-
req_version_base = req_version.split("+")[0]
36-
if req_name == req_spec_name:
37-
name_to_req_str[req_name] = f"{req_name}=={req_version_base}"
38-
else:
39-
continue
40-
41-
return name_to_req_str
29+
logger = logging.getLogger(__name__)
4230

4331

4432
def _infer_model_framework(model_class: str):
@@ -82,11 +70,40 @@ def _model_class(model: Any):
8270

8371

8472
def infer_python_version() -> str:
85-
python_major_minor = f"py{sys.version_info.major}{sys.version_info.minor}"
86-
# might want to fix up this logic
87-
if python_major_minor not in PYTHON_VERSIONS:
88-
python_major_minor = None
89-
return python_major_minor
73+
return f"py{sys.version_info.major}{sys.version_info.minor}"
74+
75+
76+
def map_to_supported_python_version(python_version: str) -> str:
77+
"""Map python version to truss supported python version.
78+
79+
Currently, it maps any versions greater than 3.9 to 3.9.
80+
81+
Args:
82+
python_version: in the form py[major_version][minor_version] e.g. py39,
83+
py310
84+
"""
85+
python_major_version = int(python_version[2:3])
86+
python_minor_version = int(python_version[3:])
87+
88+
if python_major_version > 3:
89+
raise NotImplementedError("Only python version 3 is supported")
90+
91+
# TODO(pankaj) Add full support for 3.10 and 3.11, this is stop-gap.
92+
if python_minor_version > 9:
93+
logger.info(
94+
f"Mapping python version {python_major_version}.{python_minor_version}"
95+
" to 3.9, the highest version that Truss currently supports."
96+
)
97+
return "py39"
98+
99+
if python_minor_version < 7:
100+
logger.info(
101+
f"Mapping python version {python_major_version}.{python_minor_version}"
102+
" to 3.7, the lowest version that Truss currently supports."
103+
)
104+
return "py37"
105+
106+
return python_version
90107

91108

92109
def infer_model_information(model: Any) -> ModelBuildStageOne:

truss/tests/test_model_inference.py

+17
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from truss.constants import PYTORCH
33
from truss.model_inference import (
44
infer_model_information,
5+
map_to_supported_python_version,
56
validate_provided_parameters_with_model,
67
)
78

@@ -36,3 +37,19 @@ def test_infer_model_information(pytorch_model_with_init_args):
3637
model_info = infer_model_information(pytorch_model_with_init_args[0])
3738
assert model_info.model_framework == PYTORCH
3839
assert model_info.model_type == "MyModel"
40+
41+
42+
@pytest.mark.parametrize(
43+
"python_version, expected_python_version",
44+
[
45+
("py37", "py37"),
46+
("py38", "py38"),
47+
("py39", "py39"),
48+
("py310", "py39"),
49+
("py311", "py39"),
50+
("py36", "py37"),
51+
],
52+
)
53+
def test_map_to_supported_python_version(python_version, expected_python_version):
54+
out_python_version = map_to_supported_python_version(python_version)
55+
assert out_python_version == expected_python_version

0 commit comments

Comments
 (0)