|
27 | 27 | from truss.remote import remote_factory
|
28 | 28 | from truss.remote.baseten import core as b10_core
|
29 | 29 | from truss.remote.baseten import custom_types as b10_types
|
| 30 | +from truss.remote.baseten import error as b10_errors |
30 | 31 | from truss.remote.baseten import remote as b10_remote
|
31 | 32 | from truss.remote.baseten import service as b10_service
|
32 | 33 | from truss.truss_handle import truss_handle
|
@@ -495,6 +496,114 @@ def _create_chains_secret_if_missing(remote_provider: b10_remote.BasetenRemote)
|
495 | 496 | # Watch / Live Patching ################################################################
|
496 | 497 |
|
497 | 498 |
|
| 499 | +def _create_watch_filter(root_dir: pathlib.Path): |
| 500 | + ignore_patterns = truss_path.load_trussignore_patterns_from_truss_dir(root_dir) |
| 501 | + |
| 502 | + def watch_filter(_: watchfiles.Change, path: str) -> bool: |
| 503 | + return not truss_path.is_ignored(pathlib.Path(path), ignore_patterns) |
| 504 | + |
| 505 | + logging.getLogger("watchfiles.main").disabled = True |
| 506 | + return ignore_patterns, watch_filter |
| 507 | + |
| 508 | + |
| 509 | +def _handle_intercepted_logs(logs: list[str], console: "rich_console.Console"): |
| 510 | + if logs: |
| 511 | + formatted_logs = textwrap.indent("\n".join(logs), " " * 4) |
| 512 | + console.print(f"Intercepted logs from importing source code:\n{formatted_logs}") |
| 513 | + |
| 514 | + |
| 515 | +def _handle_import_error( |
| 516 | + exception: Exception, |
| 517 | + console: "rich_console.Console", |
| 518 | + error_console: "rich_console.Console", |
| 519 | + stack_trace: Optional[str] = None, |
| 520 | +): |
| 521 | + error_console.print( |
| 522 | + "Source files were changed, but pre-conditions for " |
| 523 | + "live patching are not given. Most likely there is a " |
| 524 | + "syntax error in the source files or names changed. " |
| 525 | + "Try to fix the issue and save the file. Error:\n" |
| 526 | + f"{textwrap.indent(str(exception), ' ' * 4)}" |
| 527 | + ) |
| 528 | + if stack_trace: |
| 529 | + error_console.print(stack_trace) |
| 530 | + |
| 531 | + console.print( |
| 532 | + "The watcher will continue and if you can resolve the " |
| 533 | + "issue, subsequent patches might succeed.", |
| 534 | + style="blue", |
| 535 | + ) |
| 536 | + |
| 537 | + |
| 538 | +class _ModelWatcher: |
| 539 | + _source: pathlib.Path |
| 540 | + _model_name: str |
| 541 | + _remote_provider: b10_remote.BasetenRemote |
| 542 | + _ignore_patterns: list[str] |
| 543 | + _watch_filter: Callable[[watchfiles.Change, str], bool] |
| 544 | + _console: "rich_console.Console" |
| 545 | + _error_console: "rich_console.Console" |
| 546 | + |
| 547 | + def __init__( |
| 548 | + self, |
| 549 | + source: pathlib.Path, |
| 550 | + model_name: str, |
| 551 | + remote_provider: b10_remote.BasetenRemote, |
| 552 | + console: "rich_console.Console", |
| 553 | + error_console: "rich_console.Console", |
| 554 | + ) -> None: |
| 555 | + self._source = source |
| 556 | + self._model_name = model_name |
| 557 | + self._remote_provider = remote_provider |
| 558 | + self._console = console |
| 559 | + self._error_console = error_console |
| 560 | + self._ignore_patterns, self._watch_filter = _create_watch_filter( |
| 561 | + source.absolute().parent |
| 562 | + ) |
| 563 | + |
| 564 | + dev_version = b10_core.get_dev_version(self._remote_provider.api, model_name) |
| 565 | + if not dev_version: |
| 566 | + raise b10_errors.RemoteError( |
| 567 | + "No development model found. Run `truss push` then try again." |
| 568 | + ) |
| 569 | + |
| 570 | + def _patch(self) -> None: |
| 571 | + exception_raised = None |
| 572 | + with log_utils.LogInterceptor() as log_interceptor, self._console.status( |
| 573 | + " Live Patching Model.\n", spinner="arrow3" |
| 574 | + ): |
| 575 | + try: |
| 576 | + gen_truss_path = code_gen.gen_truss_model_from_source(self._source) |
| 577 | + return self._remote_provider.patch( |
| 578 | + gen_truss_path, |
| 579 | + self._ignore_patterns, |
| 580 | + self._console, |
| 581 | + self._error_console, |
| 582 | + ) |
| 583 | + except Exception as e: |
| 584 | + exception_raised = e |
| 585 | + finally: |
| 586 | + logs = log_interceptor.get_logs() |
| 587 | + |
| 588 | + _handle_intercepted_logs(logs, self._console) |
| 589 | + if exception_raised: |
| 590 | + _handle_import_error(exception_raised, self._console, self._error_console) |
| 591 | + |
| 592 | + def watch(self) -> None: |
| 593 | + # Perform one initial patch at startup. |
| 594 | + self._patch() |
| 595 | + self._console.print("👀 Watching for new changes.", style="blue") |
| 596 | + |
| 597 | + # TODO(nikhil): Improve detection of directory structure, since right now |
| 598 | + # we assume a flat structure |
| 599 | + root_dir = self._source.absolute().parent |
| 600 | + for _ in watchfiles.watch( |
| 601 | + root_dir, watch_filter=self._watch_filter, raise_interrupt=False |
| 602 | + ): |
| 603 | + self._patch() |
| 604 | + self._console.print("👀 Watching for new changes.", style="blue") |
| 605 | + |
| 606 | + |
498 | 607 | class _Watcher:
|
499 | 608 | _source: pathlib.Path
|
500 | 609 | _entrypoint: Optional[str]
|
@@ -573,16 +682,10 @@ def __init__(
|
573 | 682 |
|
574 | 683 | self._chainlet_data = {c.name: c for c in deployed_chainlets}
|
575 | 684 | self._assert_chainlet_names_same(chainlet_names)
|
576 |
| - self._ignore_patterns = truss_path.load_trussignore_patterns_from_truss_dir( |
| 685 | + self._ignore_patterns, self._watch_filter = _create_watch_filter( |
577 | 686 | self._chain_root
|
578 | 687 | )
|
579 | 688 |
|
580 |
| - def watch_filter(_: watchfiles.Change, path: str) -> bool: |
581 |
| - return not truss_path.is_ignored(pathlib.Path(path), self._ignore_patterns) |
582 |
| - |
583 |
| - logging.getLogger("watchfiles.main").disabled = True |
584 |
| - self._watch_filter = watch_filter |
585 |
| - |
586 | 689 | @property
|
587 | 690 | def _original_chainlet_names(self) -> set[str]:
|
588 | 691 | return set(self._chainlet_data.keys())
|
@@ -665,27 +768,13 @@ def _patch(self, executor: concurrent.futures.Executor) -> None:
|
665 | 768 | finally:
|
666 | 769 | logs = log_interceptor.get_logs()
|
667 | 770 |
|
668 |
| - if logs: |
669 |
| - formatted_logs = textwrap.indent("\n".join(logs), " " * 4) |
670 |
| - self._console.print( |
671 |
| - f"Intercepted logs from importing chain source code:\n{formatted_logs}" |
672 |
| - ) |
673 |
| - |
| 771 | + _handle_intercepted_logs(logs, self._console) |
674 | 772 | if exception_raised:
|
675 |
| - self._error_console.print( |
676 |
| - "Source files were changed, but pre-conditions for " |
677 |
| - "live patching are not given. Most likely there is a " |
678 |
| - "syntax in the source files or chainlet names changed. " |
679 |
| - "Try to fix the issue and save the file. Error:\n" |
680 |
| - f"{textwrap.indent(str(exception_raised), ' ' * 4)}" |
681 |
| - ) |
682 |
| - if self._show_stack_trace: |
683 |
| - self._error_console.print(stack_trace) |
684 |
| - |
685 |
| - self._console.print( |
686 |
| - "The watcher will continue and if you can resolve the " |
687 |
| - "issue, subsequent patches might succeed.", |
688 |
| - style="blue", |
| 773 | + _handle_import_error( |
| 774 | + exception_raised, |
| 775 | + self._console, |
| 776 | + self._error_console, |
| 777 | + stack_trace=stack_trace if self._show_stack_trace else None, |
689 | 778 | )
|
690 | 779 | return
|
691 | 780 |
|
@@ -775,3 +864,20 @@ def watch(
|
775 | 864 | included_chainlets,
|
776 | 865 | )
|
777 | 866 | patcher.watch()
|
| 867 | + |
| 868 | + |
| 869 | +def watch_model( |
| 870 | + source: pathlib.Path, |
| 871 | + model_name: str, |
| 872 | + remote_provider: b10_remote.TrussRemote, |
| 873 | + console: "rich_console.Console", |
| 874 | + error_console: "rich_console.Console", |
| 875 | +): |
| 876 | + patcher = _ModelWatcher( |
| 877 | + source=source, |
| 878 | + model_name=model_name, |
| 879 | + remote_provider=cast(b10_remote.BasetenRemote, remote_provider), |
| 880 | + console=console, |
| 881 | + error_console=error_console, |
| 882 | + ) |
| 883 | + patcher.watch() |
0 commit comments