11
11
import rich
12
12
import rich_click as click
13
13
import truss
14
- from truss .cli .create import select_server_backend
14
+ from InquirerPy import inquirer
15
+ from truss .cli .create import ask_name , select_server_backend
15
16
from truss .remote .remote_cli import inquire_model_name , inquire_remote_name
16
17
from truss .remote .remote_factory import RemoteFactory
18
+ from truss .truss_config import ModelServer
17
19
18
20
logging .basicConfig (level = logging .INFO )
19
21
@@ -51,6 +53,8 @@ def error_handling(f: Callable[..., object]):
51
53
def wrapper (* args , ** kwargs ):
52
54
try :
53
55
f (* args , ** kwargs )
56
+ except click .UsageError as e :
57
+ raise e # You can re-raise the exception or handle it different
54
58
except Exception as e :
55
59
click .echo (e )
56
60
@@ -93,8 +97,15 @@ def image():
93
97
default = False ,
94
98
help = "Create a trainable truss." ,
95
99
)
100
+ @click .option (
101
+ "-b" ,
102
+ "--backend" ,
103
+ show_default = True ,
104
+ default = ModelServer .TrussServer .value ,
105
+ type = click .Choice ([server .value for server in ModelServer ]),
106
+ )
96
107
@error_handling
97
- def init (target_directory , trainable ) -> None :
108
+ def init (target_directory , trainable , backend ) -> None :
98
109
"""Create a new truss.
99
110
100
111
TARGET_DIRECTORY: A Truss is created in this directory
@@ -104,13 +115,15 @@ def init(target_directory, trainable) -> None:
104
115
f'Error: Directory "{ target_directory } " already exists and cannot be overwritten.'
105
116
)
106
117
tr_path = Path (target_directory )
107
- build_config = select_server_backend ()
118
+ build_config = select_server_backend (ModelServer [backend ])
119
+ model_name = ask_name ()
108
120
truss .init (
109
121
target_directory = target_directory ,
110
122
trainable = trainable ,
111
123
build_config = build_config ,
124
+ model_name = model_name ,
112
125
)
113
- click .echo (f"Truss was created in { tr_path } " )
126
+ click .echo (f"Truss { model_name } was created in { tr_path . absolute () } " )
114
127
115
128
116
129
@image .command () # type: ignore
@@ -184,10 +197,18 @@ def run(target_directory: str, build_dir: Path, tag, port, attach) -> None:
184
197
required = False ,
185
198
help = "Name of the remote in .trussrc to patch changes to" ,
186
199
)
200
+ @click .option (
201
+ "--logs" ,
202
+ is_flag = True ,
203
+ show_default = True ,
204
+ default = False ,
205
+ help = "Automatically open remote logs tab" ,
206
+ )
187
207
@error_handling
188
208
def watch (
189
209
target_directory : str ,
190
210
remote : str ,
211
+ logs : bool ,
191
212
) -> None :
192
213
"""
193
214
Seamless remote development with truss
@@ -211,9 +232,13 @@ def watch(
211
232
sys .exit (1 )
212
233
213
234
logs_url = remote_provider .get_remote_logs_url (model_name ) # type: ignore[attr-defined]
214
- rich .print (f"🪵 View logs for your deployment at { logs_url } " )
215
- webbrowser .open (logs_url )
216
- rich .print (f"👀 Watching for changes to truss at '{ target_directory } ' ..." )
235
+ rich .print (f"🪵 View logs for your deployment at { logs_url } " )
236
+ if not logs :
237
+ logs = inquirer .confirm (
238
+ message = "🗂 Open logs in a new tab?" , default = True
239
+ ).execute ()
240
+ if logs :
241
+ webbrowser .open_new_tab (logs_url )
217
242
remote_provider .sync_truss_to_dev_version_by_name (model_name , target_directory ) # type: ignore
218
243
219
244
@@ -273,14 +298,18 @@ def predict(
273
298
274
299
model_name = tr .spec .config .model_name
275
300
if not model_name :
276
- raise ValueError ("Model name not set. Did you `truss push`?" )
301
+ raise click .UsageError (
302
+ "You must provide exactly one of '--data (-d)' or '--file (-f)' options."
303
+ )
277
304
278
305
if data is not None :
279
306
request_data = json .loads (data )
280
307
elif file is not None :
281
308
request_data = json .loads (Path (file ).read_text ())
282
309
else :
283
- raise ValueError ("At least one of request or request-file must be supplied." )
310
+ raise click .UsageError (
311
+ "You must provide exactly one of '--data (-d)' or '--file (-f)' options."
312
+ )
284
313
285
314
service = remote_provider .get_baseten_service (model_name , published ) # type: ignore
286
315
result = service .predict (request_data )
0 commit comments