Skip to content

Commit 2bd2764

Browse files
authored
Merge pull request #698 from basetenlabs/bump-version-0.7.12
Release 0.7.12
2 parents 6a65458 + 776f686 commit 2bd2764

File tree

3 files changed

+70
-39
lines changed

3 files changed

+70
-39
lines changed

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "truss"
3-
version = "0.7.11"
3+
version = "0.7.12"
44
description = "A seamless bridge from model development to model delivery"
55
license = "MIT"
66
readme = "README.md"

truss/remote/baseten/api.py

+44-26
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import logging
2-
from typing import Optional
32

43
import requests
54
from truss.remote.baseten.auth import AuthService
@@ -59,38 +58,57 @@ def model_s3_upload_credentials(self):
5958

6059
def create_model_from_truss(
6160
self,
62-
model_name,
63-
s3_key,
64-
config,
65-
semver_bump,
66-
client_version,
67-
is_trusted=False,
68-
model_id: Optional[str] = None,
61+
model_name: str,
62+
s3_key: str,
63+
config: str,
64+
semver_bump: str,
65+
client_version: str,
66+
is_trusted: bool,
6967
):
70-
if model_id:
71-
mutation = "create_model_version_from_truss"
72-
first_arg = f'model_id: "{model_id}"'
73-
else:
74-
mutation = "create_model_from_truss"
75-
first_arg = f'name: "{model_name}"'
76-
7768
query_string = f"""
7869
mutation {{
79-
{mutation}({first_arg},
80-
s3_key: "{s3_key}",
81-
config: "{config}",
82-
semver_bump: "{semver_bump}",
83-
client_version: "{client_version}",
84-
is_trusted: {'true' if is_trusted else 'false'}
85-
) {{
86-
id,
87-
name,
88-
version_id
70+
create_model_from_truss(
71+
name: "{model_name}",
72+
s3_key: "{s3_key}",
73+
config: "{config}",
74+
semver_bump: "{semver_bump}",
75+
client_version: "{client_version}",
76+
is_trusted: {'true' if is_trusted else 'false'}
77+
) {{
78+
id,
79+
name,
80+
version_id
81+
}}
8982
}}
83+
"""
84+
resp = self._post_graphql_query(query_string)
85+
return resp["data"]["create_model_from_truss"]
86+
87+
def create_model_version_from_truss(
88+
self,
89+
model_id: str,
90+
s3_key: str,
91+
config: str,
92+
semver_bump: str,
93+
client_version: str,
94+
is_trusted: bool,
95+
):
96+
query_string = f"""
97+
mutation {{
98+
create_model_version_from_truss(
99+
model_id: "{model_id}"
100+
s3_key: "{s3_key}",
101+
config: "{config}",
102+
semver_bump: "{semver_bump}",
103+
client_version: "{client_version}",
104+
is_trusted: {'true' if is_trusted else 'false'}
105+
) {{
106+
id
107+
}}
90108
}}
91109
"""
92110
resp = self._post_graphql_query(query_string)
93-
return resp["data"][mutation]
111+
return resp["data"]["create_model_version_from_truss"]
94112

95113
def create_development_model_from_truss(
96114
self,

truss/remote/baseten/core.py

+25-12
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,8 @@ def create_truss_service(
106106
model_name: str,
107107
s3_key: str,
108108
config: str,
109-
semver_bump: Optional[str] = "MINOR",
110-
is_trusted: Optional[bool] = False,
109+
semver_bump: str = "MINOR",
110+
is_trusted: bool = False,
111111
is_draft: Optional[bool] = False,
112112
model_id: Optional[str] = None,
113113
) -> Tuple[str, str]:
@@ -133,15 +133,28 @@ def create_truss_service(
133133
f"truss=={truss.version()}",
134134
is_trusted,
135135
)
136-
else:
136+
137+
return (model_version_json["id"], model_version_json["version_id"])
138+
139+
if model_id is None:
137140
model_version_json = api.create_model_from_truss(
138-
model_name,
139-
s3_key,
140-
config,
141-
semver_bump,
142-
f"truss=={truss.version()}",
143-
is_trusted,
144-
model_id,
141+
model_name=model_name,
142+
s3_key=s3_key,
143+
config=config,
144+
semver_bump=semver_bump,
145+
client_version=f"truss=={truss.version()}",
146+
is_trusted=is_trusted,
145147
)
146-
147-
return (model_version_json["id"], model_version_json["version_id"])
148+
return (model_version_json["id"], model_version_json["version_id"])
149+
150+
# Case where there is a model id already, create another version
151+
model_version_json = api.create_model_version_from_truss(
152+
model_id=model_id,
153+
s3_key=s3_key,
154+
config=config,
155+
semver_bump=semver_bump,
156+
client_version=f"truss=={truss.version()}",
157+
is_trusted=is_trusted,
158+
)
159+
model_version_id = model_version_json["id"]
160+
return (model_id, model_version_id)

0 commit comments

Comments
 (0)