Skip to content

Commit 56608ff

Browse files
authored
Fix patch behavior for trusses using python DX (#1337)
1 parent 8037619 commit 56608ff

File tree

2 files changed

+149
-30
lines changed

2 files changed

+149
-30
lines changed

truss-chains/truss_chains/deployment/deployment_client.py

+133-27
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from truss.remote import remote_factory
2828
from truss.remote.baseten import core as b10_core
2929
from truss.remote.baseten import custom_types as b10_types
30+
from truss.remote.baseten import error as b10_errors
3031
from truss.remote.baseten import remote as b10_remote
3132
from truss.remote.baseten import service as b10_service
3233
from truss.truss_handle import truss_handle
@@ -495,6 +496,114 @@ def _create_chains_secret_if_missing(remote_provider: b10_remote.BasetenRemote)
495496
# Watch / Live Patching ################################################################
496497

497498

499+
def _create_watch_filter(root_dir: pathlib.Path):
500+
ignore_patterns = truss_path.load_trussignore_patterns_from_truss_dir(root_dir)
501+
502+
def watch_filter(_: watchfiles.Change, path: str) -> bool:
503+
return not truss_path.is_ignored(pathlib.Path(path), ignore_patterns)
504+
505+
logging.getLogger("watchfiles.main").disabled = True
506+
return ignore_patterns, watch_filter
507+
508+
509+
def _handle_intercepted_logs(logs: list[str], console: "rich_console.Console"):
510+
if logs:
511+
formatted_logs = textwrap.indent("\n".join(logs), " " * 4)
512+
console.print(f"Intercepted logs from importing source code:\n{formatted_logs}")
513+
514+
515+
def _handle_import_error(
516+
exception: Exception,
517+
console: "rich_console.Console",
518+
error_console: "rich_console.Console",
519+
stack_trace: Optional[str] = None,
520+
):
521+
error_console.print(
522+
"Source files were changed, but pre-conditions for "
523+
"live patching are not given. Most likely there is a "
524+
"syntax error in the source files or names changed. "
525+
"Try to fix the issue and save the file. Error:\n"
526+
f"{textwrap.indent(str(exception), ' ' * 4)}"
527+
)
528+
if stack_trace:
529+
error_console.print(stack_trace)
530+
531+
console.print(
532+
"The watcher will continue and if you can resolve the "
533+
"issue, subsequent patches might succeed.",
534+
style="blue",
535+
)
536+
537+
538+
class _ModelWatcher:
539+
_source: pathlib.Path
540+
_model_name: str
541+
_remote_provider: b10_remote.BasetenRemote
542+
_ignore_patterns: list[str]
543+
_watch_filter: Callable[[watchfiles.Change, str], bool]
544+
_console: "rich_console.Console"
545+
_error_console: "rich_console.Console"
546+
547+
def __init__(
548+
self,
549+
source: pathlib.Path,
550+
model_name: str,
551+
remote_provider: b10_remote.BasetenRemote,
552+
console: "rich_console.Console",
553+
error_console: "rich_console.Console",
554+
) -> None:
555+
self._source = source
556+
self._model_name = model_name
557+
self._remote_provider = remote_provider
558+
self._console = console
559+
self._error_console = error_console
560+
self._ignore_patterns, self._watch_filter = _create_watch_filter(
561+
source.absolute().parent
562+
)
563+
564+
dev_version = b10_core.get_dev_version(self._remote_provider.api, model_name)
565+
if not dev_version:
566+
raise b10_errors.RemoteError(
567+
"No development model found. Run `truss push` then try again."
568+
)
569+
570+
def _patch(self) -> None:
571+
exception_raised = None
572+
with log_utils.LogInterceptor() as log_interceptor, self._console.status(
573+
" Live Patching Model.\n", spinner="arrow3"
574+
):
575+
try:
576+
gen_truss_path = code_gen.gen_truss_model_from_source(self._source)
577+
return self._remote_provider.patch(
578+
gen_truss_path,
579+
self._ignore_patterns,
580+
self._console,
581+
self._error_console,
582+
)
583+
except Exception as e:
584+
exception_raised = e
585+
finally:
586+
logs = log_interceptor.get_logs()
587+
588+
_handle_intercepted_logs(logs, self._console)
589+
if exception_raised:
590+
_handle_import_error(exception_raised, self._console, self._error_console)
591+
592+
def watch(self) -> None:
593+
# Perform one initial patch at startup.
594+
self._patch()
595+
self._console.print("👀 Watching for new changes.", style="blue")
596+
597+
# TODO(nikhil): Improve detection of directory structure, since right now
598+
# we assume a flat structure
599+
root_dir = self._source.absolute().parent
600+
for _ in watchfiles.watch(
601+
root_dir, watch_filter=self._watch_filter, raise_interrupt=False
602+
):
603+
self._patch()
604+
self._console.print("👀 Watching for new changes.", style="blue")
605+
606+
498607
class _Watcher:
499608
_source: pathlib.Path
500609
_entrypoint: Optional[str]
@@ -573,16 +682,10 @@ def __init__(
573682

574683
self._chainlet_data = {c.name: c for c in deployed_chainlets}
575684
self._assert_chainlet_names_same(chainlet_names)
576-
self._ignore_patterns = truss_path.load_trussignore_patterns_from_truss_dir(
685+
self._ignore_patterns, self._watch_filter = _create_watch_filter(
577686
self._chain_root
578687
)
579688

580-
def watch_filter(_: watchfiles.Change, path: str) -> bool:
581-
return not truss_path.is_ignored(pathlib.Path(path), self._ignore_patterns)
582-
583-
logging.getLogger("watchfiles.main").disabled = True
584-
self._watch_filter = watch_filter
585-
586689
@property
587690
def _original_chainlet_names(self) -> set[str]:
588691
return set(self._chainlet_data.keys())
@@ -665,27 +768,13 @@ def _patch(self, executor: concurrent.futures.Executor) -> None:
665768
finally:
666769
logs = log_interceptor.get_logs()
667770

668-
if logs:
669-
formatted_logs = textwrap.indent("\n".join(logs), " " * 4)
670-
self._console.print(
671-
f"Intercepted logs from importing chain source code:\n{formatted_logs}"
672-
)
673-
771+
_handle_intercepted_logs(logs, self._console)
674772
if exception_raised:
675-
self._error_console.print(
676-
"Source files were changed, but pre-conditions for "
677-
"live patching are not given. Most likely there is a "
678-
"syntax in the source files or chainlet names changed. "
679-
"Try to fix the issue and save the file. Error:\n"
680-
f"{textwrap.indent(str(exception_raised), ' ' * 4)}"
681-
)
682-
if self._show_stack_trace:
683-
self._error_console.print(stack_trace)
684-
685-
self._console.print(
686-
"The watcher will continue and if you can resolve the "
687-
"issue, subsequent patches might succeed.",
688-
style="blue",
773+
_handle_import_error(
774+
exception_raised,
775+
self._console,
776+
self._error_console,
777+
stack_trace=stack_trace if self._show_stack_trace else None,
689778
)
690779
return
691780

@@ -775,3 +864,20 @@ def watch(
775864
included_chainlets,
776865
)
777866
patcher.watch()
867+
868+
869+
def watch_model(
870+
source: pathlib.Path,
871+
model_name: str,
872+
remote_provider: b10_remote.TrussRemote,
873+
console: "rich_console.Console",
874+
error_console: "rich_console.Console",
875+
):
876+
patcher = _ModelWatcher(
877+
source=source,
878+
model_name=model_name,
879+
remote_provider=cast(b10_remote.BasetenRemote, remote_provider),
880+
console=console,
881+
error_console=error_console,
882+
)
883+
patcher.watch()

truss/cli/cli.py

+16-3
Original file line numberDiff line numberDiff line change
@@ -397,9 +397,22 @@ def watch(target_directory: str, remote: str) -> None:
397397
console.print(
398398
f"🪵 View logs for your deployment at {_format_link(service.logs_url)}"
399399
)
400-
remote_provider.sync_truss_to_dev_version_by_name(
401-
model_name, target_directory, console, error_console
402-
)
400+
401+
if not os.path.isfile(target_directory):
402+
remote_provider.sync_truss_to_dev_version_by_name(
403+
model_name, target_directory, console, error_console
404+
)
405+
else:
406+
# These imports are delayed, to handle pydantic v1 envs gracefully.
407+
from truss_chains.deployment import deployment_client
408+
409+
deployment_client.watch_model(
410+
source=Path(target_directory),
411+
model_name=model_name,
412+
remote_provider=remote_provider,
413+
console=console,
414+
error_console=error_console,
415+
)
403416

404417

405418
# Chains Stuff #########################################################################

0 commit comments

Comments
 (0)