Skip to content

Commit 24cb02e

Browse files
authored
Merge pull request #570 from basetenlabs/main
Release 0.6.0
2 parents 5d8e09c + 5f0fc96 commit 24cb02e

File tree

12 files changed

+113
-48
lines changed

12 files changed

+113
-48
lines changed

docs/mint.json

+6-1
Original file line numberDiff line numberDiff line change
@@ -179,5 +179,10 @@
179179
"name": "Issues",
180180
"url": "https://github.com/basetenlabs/truss/issues"
181181
}
182-
]
182+
],
183+
"analytics": {
184+
"gtm": {
185+
"tagId": "GTM-WXD4NQTW"
186+
}
187+
}
183188
}

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "truss"
3-
version = "0.5.9rc1"
3+
version = "0.6.0"
44
description = "A seamless bridge from model development to model delivery"
55
license = "MIT"
66
readme = "README.md"

truss/build.py

+2
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,7 @@ def init(
218218
bundled_packages: Optional[List[str]] = None,
219219
trainable: bool = False,
220220
build_config: Optional[Build] = None,
221+
model_name: Optional[str] = None,
221222
) -> TrussHandle:
222223
"""
223224
Initialize an empty placeholder Truss. A Truss is a build context designed
@@ -230,6 +231,7 @@ def init(
230231
Truss in. The directory is created if it doesn't exist.
231232
"""
232233
config = TrussConfig(
234+
model_name=model_name,
233235
python_version=map_to_supported_python_version(infer_python_version()),
234236
)
235237

truss/cli/cli.py

+38-9
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,11 @@
1111
import rich
1212
import rich_click as click
1313
import truss
14-
from truss.cli.create import select_server_backend
14+
from InquirerPy import inquirer
15+
from truss.cli.create import ask_name, select_server_backend
1516
from truss.remote.remote_cli import inquire_model_name, inquire_remote_name
1617
from truss.remote.remote_factory import RemoteFactory
18+
from truss.truss_config import ModelServer
1719

1820
logging.basicConfig(level=logging.INFO)
1921

@@ -51,6 +53,8 @@ def error_handling(f: Callable[..., object]):
5153
def wrapper(*args, **kwargs):
5254
try:
5355
f(*args, **kwargs)
56+
except click.UsageError as e:
57+
raise e # You can re-raise the exception or handle it differently
5458
except Exception as e:
5559
click.echo(e)
5660

@@ -93,8 +97,15 @@ def image():
9397
default=False,
9498
help="Create a trainable truss.",
9599
)
100+
@click.option(
101+
"-b",
102+
"--backend",
103+
show_default=True,
104+
default=ModelServer.TrussServer.value,
105+
type=click.Choice([server.value for server in ModelServer]),
106+
)
96107
@error_handling
97-
def init(target_directory, trainable) -> None:
108+
def init(target_directory, trainable, backend) -> None:
98109
"""Create a new truss.
99110
100111
TARGET_DIRECTORY: A Truss is created in this directory
@@ -104,13 +115,15 @@ def init(target_directory, trainable) -> None:
104115
f'Error: Directory "{target_directory}" already exists and cannot be overwritten.'
105116
)
106117
tr_path = Path(target_directory)
107-
build_config = select_server_backend()
118+
build_config = select_server_backend(ModelServer[backend])
119+
model_name = ask_name()
108120
truss.init(
109121
target_directory=target_directory,
110122
trainable=trainable,
111123
build_config=build_config,
124+
model_name=model_name,
112125
)
113-
click.echo(f"Truss was created in {tr_path}")
126+
click.echo(f"Truss {model_name} was created in {tr_path.absolute()}")
114127

115128

116129
@image.command() # type: ignore
@@ -184,10 +197,18 @@ def run(target_directory: str, build_dir: Path, tag, port, attach) -> None:
184197
required=False,
185198
help="Name of the remote in .trussrc to patch changes to",
186199
)
200+
@click.option(
201+
"--logs",
202+
is_flag=True,
203+
show_default=True,
204+
default=False,
205+
help="Automatically open remote logs tab",
206+
)
187207
@error_handling
188208
def watch(
189209
target_directory: str,
190210
remote: str,
211+
logs: bool,
191212
) -> None:
192213
"""
193214
Seamless remote development with truss
@@ -211,9 +232,13 @@ def watch(
211232
sys.exit(1)
212233

213234
logs_url = remote_provider.get_remote_logs_url(model_name) # type: ignore[attr-defined]
214-
rich.print(f"🪵 View logs for your deployment at {logs_url}")
215-
webbrowser.open(logs_url)
216-
rich.print(f"👀 Watching for changes to truss at '{target_directory}' ...")
235+
rich.print(f"🪵 View logs for your deployment at {logs_url}")
236+
if not logs:
237+
logs = inquirer.confirm(
238+
message="🗂 Open logs in a new tab?", default=True
239+
).execute()
240+
if logs:
241+
webbrowser.open_new_tab(logs_url)
217242
remote_provider.sync_truss_to_dev_version_by_name(model_name, target_directory) # type: ignore
218243

219244

@@ -273,14 +298,18 @@ def predict(
273298

274299
model_name = tr.spec.config.model_name
275300
if not model_name:
276-
raise ValueError("Model name not set. Did you `truss push`?")
301+
raise click.UsageError(
302+
"You must provide exactly one of '--data (-d)' or '--file (-f)' options."
303+
)
277304

278305
if data is not None:
279306
request_data = json.loads(data)
280307
elif file is not None:
281308
request_data = json.loads(Path(file).read_text())
282309
else:
283-
raise ValueError("At least one of request or request-file must be supplied.")
310+
raise click.UsageError(
311+
"You must provide exactly one of '--data (-d)' or '--file (-f)' options."
312+
)
284313

285314
service = remote_provider.get_baseten_service(model_name, published) # type: ignore
286315
result = service.predict(request_data)

truss/cli/create.py

+5-8
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,7 @@
2424
}
2525

2626

27-
def select_server_backend() -> Build:
28-
server_backend = ModelServer[
29-
inquirer.select(
30-
message="Select a server:",
31-
choices=[s.value for s in ModelServer],
32-
default=ModelServer.TrussServer.value,
33-
).execute()
34-
]
27+
def select_server_backend(server_backend: ModelServer) -> Build:
3528
follow_up_questions = REQUIRED_ARGUMENTS.get(server_backend)
3629
args = {}
3730
if follow_up_questions:
@@ -41,3 +34,7 @@ def select_server_backend() -> Build:
4134
).execute()
4235

4336
return Build(model_server=server_backend, arguments=args)
37+
38+
39+
def ask_name() -> str:
40+
return inquirer.text(message="What's the name of your model?").execute()

truss/contexts/local_loader/truss_file_syncer.py

+6
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from pathlib import Path
33
from threading import Thread
44

5+
import rich
6+
57

68
class TrussFilesSyncer(Thread):
79
"""Daemon thread that watches for changes in the user's Truss and syncs to running service."""
@@ -29,6 +31,10 @@ def run(self) -> None:
2931
# disable watchfiles logger
3032
logging.getLogger("watchfiles.main").disabled = True
3133

34+
rich.print(f"🚰 Attempting to sync truss at '{self.watch_path}' with remote")
35+
self.remote.patch(self.watch_path, self._logger)
36+
37+
rich.print(f"👀 Watching for changes to truss at '{self.watch_path}' ...")
3238
for _ in watch(
3339
self.watch_path, watch_filter=self.watch_filter, raise_interrupt=False
3440
):

truss/remote/baseten/remote.py

+18-3
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from pathlib import Path
33

44
import yaml
5+
from requests import ReadTimeout
56
from truss.contexts.local_loader.truss_file_syncer import TrussFilesSyncer
67
from truss.local.local_config_handler import LocalConfigHandler
78
from truss.remote.baseten.api import BasetenApi
@@ -137,6 +138,11 @@ def patch(
137138
except yaml.parser.ParserError:
138139
logger.error("Unable to parse config file")
139140
return
141+
except ValueError:
142+
logger.error(
143+
f"Error when reading truss from directory {watch_path}", exc_info=True
144+
)
145+
return
140146
model_name = truss_handle.spec.config.model_name
141147
dev_version = get_dev_version_info(self._api, model_name) # type: ignore
142148
truss_hash = dev_version.get("truss_hash", None)
@@ -145,7 +151,7 @@ def patch(
145151
try:
146152
patch_request = truss_handle.calc_patch(truss_hash)
147153
except Exception:
148-
logger.error("Failed to calculate patch")
154+
logger.error("Failed to calculate patch, bailing on patching")
149155
return
150156
if patch_request:
151157
if (
@@ -154,7 +160,16 @@ def patch(
154160
):
155161
logger.info("No changes observed, skipping deployment")
156162
return
157-
resp = self._api.patch_draft_truss(model_name, patch_request)
163+
try:
164+
resp = self._api.patch_draft_truss(model_name, patch_request)
165+
except ReadTimeout:
166+
logger.error(
167+
"Read Timeout when attempting to connect to remote. Bailing on patching"
168+
)
169+
return
170+
except Exception:
171+
logger.error("Failed to patch draft deployment, bailing on patching")
172+
return
158173
if not resp["succeeded"]:
159174
needs_full_deploy = resp.get("needs_full_deploy", None)
160175
if needs_full_deploy:
@@ -169,6 +184,6 @@ def patch(
169184
logger.info(
170185
resp.get(
171186
"success_message",
172-
f"Model {model_name} patched successfully.",
187+
f"Model {model_name} patched successfully",
173188
)
174189
)

truss/remote/remote_cli.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,11 @@
99
def inquire_remote_config() -> RemoteConfig:
1010
# TODO(bola): extract questions from remote
1111
rich.print("💻 Let's add a Baseten remote!")
12-
remote_url = inquirer.text(
13-
message="🌐 Baseten remote url:",
14-
default="https://app.baseten.co",
15-
qmark="",
16-
).execute()
12+
# If users need to adjust the remote url, they
13+
# can do so manually in the .trussrc file.
14+
remote_url = "https://app.baseten.co"
1715
api_key = inquirer.secret(
18-
message="🤫 Quiety paste your API_KEY:",
16+
message="🤫 Quietly paste your API_KEY:",
1917
qmark="",
2018
).execute()
2119
return RemoteConfig(

truss/templates/custom/model/model.py

+23-10
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,32 @@
1-
from typing import Any
1+
"""
2+
The `Model` class is an interface between the ML model that you're packaging and the model
3+
server that you're running it on.
4+
5+
The main methods to implement here are:
6+
* `load`: runs exactly once when the model server is spun up or patched and loads the
7+
model onto the model server. Include any logic for initializing your model, such
8+
as downloading model weights and loading the model into memory.
9+
* `predict`: runs every time the model server is called. Include any logic for model
10+
inference and return the model output.
11+
12+
See https://truss.baseten.co/quickstart for more.
13+
"""
214

315

416
class Model:
5-
def __init__(self, **kwargs) -> None:
6-
self._data_dir = kwargs["data_dir"]
7-
self._config = kwargs["config"]
8-
self._secrets = kwargs["secrets"]
17+
def __init__(self, **kwargs):
18+
# Uncomment the following to get access
19+
# to various parts of the Truss config.
20+
21+
# self._data_dir = kwargs["data_dir"]
22+
# self._config = kwargs["config"]
23+
# self._secrets = kwargs["secrets"]
924
self._model = None
1025

1126
def load(self):
1227
# Load model here and assign to self._model.
1328
pass
1429

15-
def predict(self, model_input: Any) -> Any:
16-
model_output = {}
17-
# Invoke model on model_input and calculate predictions here.
18-
model_output["predictions"] = []
19-
return model_output
30+
def predict(self, model_input):
31+
# Run model inference here
32+
return model_input

truss/tests/test_config.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -129,8 +129,8 @@ def generate_default_config():
129129
"requirements": [],
130130
"resources": {
131131
"accelerator": None,
132-
"cpu": "500m",
133-
"memory": "512Mi",
132+
"cpu": "1",
133+
"memory": "2Gi",
134134
"use_gpu": False,
135135
},
136136
"secrets": {},
@@ -153,8 +153,8 @@ def test_default_config_not_crowded_end_to_end():
153153
requirements: []
154154
resources:
155155
accelerator: null
156-
cpu: 500m
157-
memory: 512Mi
156+
cpu: '1'
157+
memory: 2Gi
158158
use_gpu: false
159159
secrets: {}
160160
system_packages: []
@@ -192,8 +192,8 @@ def test_non_default_train():
192192
updated_train = {
193193
"resources": {
194194
"accelerator": "A10G",
195-
"cpu": "500m",
196-
"memory": "512Mi",
195+
"cpu": "1",
196+
"memory": "2Gi",
197197
"use_gpu": True,
198198
},
199199
"variables": {},

truss/tests/test_truss_handle.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -856,8 +856,8 @@ def generate_default_config():
856856
"requirements": [],
857857
"resources": {
858858
"accelerator": None,
859-
"cpu": "500m",
860-
"memory": "512Mi",
859+
"cpu": "1",
860+
"memory": "2Gi",
861861
"use_gpu": False,
862862
},
863863
"secrets": {},

truss/truss_config.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@
2929
DEFAULT_SPEC_VERSION = "2.0"
3030
DEFAULT_PREDICT_CONCURRENCY = 1
3131

32-
DEFAULT_CPU = "500m"
33-
DEFAULT_MEMORY = "512Mi"
32+
DEFAULT_CPU = "1"
33+
DEFAULT_MEMORY = "2Gi"
3434
DEFAULT_USE_GPU = False
3535

3636
DEFAULT_TRAINING_CLASS_FILENAME = "train.py"

0 commit comments

Comments
 (0)