From 3cc7f6fe07a59a65affaa3e0970b1ee54e3ffa5c Mon Sep 17 00:00:00 2001 From: Griffin Tarpenning Date: Tue, 4 Mar 2025 09:37:13 -0800 Subject: [PATCH] chore(weave): add rich progress bar to client flush (#3828) --- tests/trace/test_weave_client.py | 43 ++++ weave/trace/client_progress_bar.py | 108 +++++++++ weave/trace/concurrent/futures.py | 5 + weave/trace/weave_client.py | 209 ++++++++++++++++-- .../async_batch_processor.py | 6 + 5 files changed, 356 insertions(+), 15 deletions(-) create mode 100644 weave/trace/client_progress_bar.py diff --git a/tests/trace/test_weave_client.py b/tests/trace/test_weave_client.py index f28b06f2b816..d3a5970f37ea 100644 --- a/tests/trace/test_weave_client.py +++ b/tests/trace/test_weave_client.py @@ -3,6 +3,7 @@ import json import platform import sys +import time import pydantic import pytest @@ -1886,3 +1887,45 @@ def my_op(a: int) -> int: # Local attributes override global ones assert call.attributes["env"] == "override" + + +def test_flush_progress_bar(client): + @weave.op + def op_1(): + time.sleep(1) + + op_1() + + # flush with progress bar + client.flush(use_progress_bar=True) + + # make sure there are no pending jobs + assert client._get_pending_jobs()["total_jobs"] == 0 + assert client._has_pending_jobs() == False + + +def test_flush_callback(client): + @weave.op + def op_1(): + time.sleep(1) + + op_1() + + def fake_logger(status): + assert "job_counts" in status + + # flush with callback + client.flush(callback=fake_logger) + + # make sure there are no pending jobs + assert client._get_pending_jobs()["total_jobs"] == 0 + assert client._has_pending_jobs() == False + + op_1() + + # this should also work, the callback will override the progress bar + client.flush(callback=fake_logger, use_progress_bar=True) + + # make sure there are no pending jobs + assert client._get_pending_jobs()["total_jobs"] == 0 + assert client._has_pending_jobs() == False diff --git a/weave/trace/client_progress_bar.py b/weave/trace/client_progress_bar.py new file mode 100644 index 000000000000..331d4bef712d --- /dev/null +++ b/weave/trace/client_progress_bar.py @@ -0,0 +1,108 @@ +"""Progress bar utilities for WeaveClient. + +This module provides functionality for displaying progress bars when flushing +tasks in the WeaveClient. +""" + +from typing import Callable + +from rich.console import Console +from rich.progress import ( + BarColumn, + Progress, + SpinnerColumn, + TaskProgressColumn, + TextColumn, + TimeElapsedColumn, +) + +from weave.trace.weave_client import FlushStatus + + +def create_progress_bar_callback() -> Callable[[FlushStatus], None]: + """Create a callback function that displays a progress bar for flush status. + + Returns: + A callback function that can be passed to WeaveClient._flush. + """ + console = Console() + + # Create a progress bar instance + progress = Progress( + SpinnerColumn(), + TextColumn("[bold blue]{task.description}"), + BarColumn(bar_width=None, complete_style="magenta"), + TaskProgressColumn(), + TimeElapsedColumn(), + console=console, + refresh_per_second=10, + expand=True, + transient=False, + ) + + # Start the progress display + progress.start() + + # Create a task for tracking progress + task_id = None + + def progress_callback(status: FlushStatus) -> None: + """Update the progress bar based on the flush status. + + Args: + status: The current flush status. + """ + nonlocal task_id + + counts = status["job_counts"] + + # If this is the first update, create the task + if task_id is None: + if counts["total_jobs"] == 0: + # No jobs to track, just return + progress.stop() + return + + # Print initial message + if not progress.live.is_started: + print(f"Flushing {counts['total_jobs']} pending tasks...") + + # Create the task + task_id = progress.add_task("Flushing tasks", total=counts["total_jobs"]) + + # If there are no more pending jobs, complete the progress bar + if not status["has_pending_jobs"]: + progress.update( + task_id, + completed=status["max_total_jobs"], + total=status["max_total_jobs"], + ) + progress.stop() + return + + # If new jobs were added, update the total + if status["max_total_jobs"] > progress.tasks[task_id].total: + progress.update(task_id, total=status["max_total_jobs"]) + + # Update progress bar with completed jobs + if status["completed_since_last_update"] > 0: + progress.update(task_id, advance=status["completed_since_last_update"]) + + # Format job details for description + job_details = [] + if counts["main_jobs"] > 0: + job_details.append(f"{counts['main_jobs']} main") + if counts["fastlane_jobs"] > 0: + job_details.append(f"{counts['fastlane_jobs']} file-upload") + if counts["call_processor_jobs"] > 0: + job_details.append(f"{counts['call_processor_jobs']} call-batch") + + job_details_str = ", ".join(job_details) if job_details else "none" + + # Update progress bar description + progress.update( + task_id, + description=f"Flushing tasks ({counts['total_jobs']} remaining: {job_details_str})", + ) + + return progress_callback diff --git a/weave/trace/concurrent/futures.py b/weave/trace/concurrent/futures.py index 96b1d610406c..aff26dad142b 100644 --- a/weave/trace/concurrent/futures.py +++ b/weave/trace/concurrent/futures.py @@ -79,6 +79,11 @@ def __init__( self._in_thread_context = ContextVar("in_deferred_context", default=False) atexit.register(self._shutdown) + @property + def num_outstanding_futures(self) -> int: + with self._active_futures_lock: + return len(self._active_futures) + def defer(self, f: Callable[..., T], *args: Any, **kwargs: Any) -> Future[T]: """ Defer a function to be executed in a thread pool. diff --git a/weave/trace/weave_client.py b/weave/trace/weave_client.py index 6500a141151f..929555415eab 100644 --- a/weave/trace/weave_client.py +++ b/weave/trace/weave_client.py @@ -8,6 +8,7 @@ import platform import re import sys +import time from collections.abc import Iterator, Sequence from concurrent.futures import Future from functools import lru_cache @@ -1970,34 +1971,212 @@ def _op_runs(self, op_def: Op) -> Sequence[Call]: def _ref_uri(self, name: str, version: str, path: str) -> str: return ObjectRef(self.entity, self.project, name, version).uri() - def flush(self) -> None: + def _send_file_create(self, req: FileCreateReq) -> Future[FileCreateRes]: + if self.future_executor_fastlane: + # If we have a separate upload worker pool, use it + return self.future_executor_fastlane.defer(self.server.file_create, req) + return self.future_executor.defer(self.server.file_create, req) + + @property + def num_outstanding_jobs(self) -> int: + """ + Returns the total number of pending jobs across all executors and the server. + + This property can be used to check the progress of background tasks + without blocking the main thread. + + Returns: + int: The total number of pending jobs + """ + total = self.future_executor.num_outstanding_futures + if self.future_executor_fastlane: + total += self.future_executor_fastlane.num_outstanding_futures + + # Add call batch uploads if available + if self._server_is_flushable: + total += self.server.call_processor.num_outstanding_jobs # type: ignore + return total + + def flush( + self, + use_progress_bar: bool = True, + callback: Callable[[FlushStatus], None] | None = None, + ) -> None: """ - An optional flushing method for the client. - Forces all background tasks to be processed, which ensures parallel processing - during main thread execution. Can improve performance when user code completes - before data has been uploaded to the server. + Flushes all background tasks to ensure they are processed. + + This method blocks until all currently enqueued jobs are processed, + displaying a progress bar to show the status of the pending tasks. + It ensures parallel processing during main thread execution and can + improve performance when user code completes before data has been + uploaded to the server. + + Args: + use_progress_bar: Whether to display a progress bar during flush. + Set to False for environments where a progress bar + would not render well (e.g., CI environments). + callback: Optional callback function that receives status updates. + Overrides use_progress_bar. """ + if use_progress_bar and callback is None: + from weave.trace.client_progress_bar import create_progress_bar_callback + + callback = create_progress_bar_callback() + + if callback is not None: + self._flush_with_callback(callback=callback) + else: + self._flush() + + def _flush_with_callback( + self, + callback: Callable[[FlushStatus], None], + refresh_interval: float = 0.1, + ) -> None: + """Used to wait until all currently enqueued jobs are processed. + + Args: + callback: Optional callback function that receives status updates. + refresh_interval: Time in seconds between status updates. + """ + # Initialize tracking variables + prev_job_counts = self._get_pending_jobs() + + total_completed = 0 + while self._has_pending_jobs(): + current_job_counts = self._get_pending_jobs() + + # If new jobs were added, update the total + if ( + current_job_counts["total_jobs"] + > prev_job_counts["total_jobs"] - total_completed + ): + new_jobs = current_job_counts["total_jobs"] - ( + prev_job_counts["total_jobs"] - total_completed + ) + prev_job_counts["total_jobs"] += new_jobs + + # Calculate completed jobs since last update + main_completed = max( + 0, prev_job_counts["main_jobs"] - current_job_counts["main_jobs"] + ) + fastlane_completed = max( + 0, + prev_job_counts["fastlane_jobs"] - current_job_counts["fastlane_jobs"], + ) + call_processor_completed = max( + 0, + prev_job_counts["call_processor_jobs"] + - current_job_counts["call_processor_jobs"], + ) + completed_this_iteration = ( + main_completed + fastlane_completed + call_processor_completed + ) + + if completed_this_iteration > 0: + total_completed += completed_this_iteration + + status = FlushStatus( + job_counts=current_job_counts, + completed_since_last_update=completed_this_iteration, + total_completed=total_completed, + max_total_jobs=prev_job_counts["total_jobs"], + has_pending_jobs=True, + ) + + callback(status) + + # Store current counts for next iteration + prev_job_counts = current_job_counts + + # Sleep briefly to allow background threads to make progress + time.sleep(refresh_interval) + + # Do the actual flush self._flush() + # Final callback with no pending jobs + final_status = FlushStatus( + job_counts=PendingJobCounts( + main_jobs=0, + fastlane_jobs=0, + call_processor_jobs=0, + total_jobs=0, + ), + completed_since_last_update=0, + total_completed=total_completed, + max_total_jobs=prev_job_counts["total_jobs"], + has_pending_jobs=False, + ) + callback(final_status) + def _flush(self) -> None: - # Used to wait until all currently enqueued jobs are processed + """Used to wait until all currently enqueued jobs are processed.""" if not self.future_executor._in_thread_context.get(): self.future_executor.flush() if self.future_executor_fastlane: self.future_executor_fastlane.flush() if self._server_is_flushable: - # We don't want to do an instance check here because it could - # be susceptible to shutdown race conditions. So we save a boolean - # _server_is_flushable and only call this if we know the server is - # flushable. The # type: ignore is safe because we check the type - # first. self.server.call_processor.stop_accepting_new_work_and_flush_queue() # type: ignore - def _send_file_create(self, req: FileCreateReq) -> Future[FileCreateRes]: + def _get_pending_jobs(self) -> PendingJobCounts: + """Get the current number of pending jobs for each type. + + Returns: + PendingJobCounts: + - main_jobs: Number of pending jobs in the main executor + - fastlane_jobs: Number of pending jobs in the fastlane executor + - call_processor_jobs: Number of pending jobs in the call processor + - total_jobs: Total number of pending jobs + """ + main_jobs = self.future_executor.num_outstanding_futures + fastlane_jobs = 0 if self.future_executor_fastlane: - # If we have a separate upload worker pool, use it - return self.future_executor_fastlane.defer(self.server.file_create, req) - return self.future_executor.defer(self.server.file_create, req) + fastlane_jobs = self.future_executor_fastlane.num_outstanding_futures + call_processor_jobs = 0 + if self._server_is_flushable: + call_processor_jobs = self.server.call_processor.num_outstanding_jobs # type: ignore + + return PendingJobCounts( + main_jobs=main_jobs, + fastlane_jobs=fastlane_jobs, + call_processor_jobs=call_processor_jobs, + total_jobs=main_jobs + fastlane_jobs + call_processor_jobs, + ) + + def _has_pending_jobs(self) -> bool: + """Check if there are any pending jobs. + + Returns: + True if there are pending jobs, False otherwise. + """ + return self._get_pending_jobs()["total_jobs"] > 0 + + +class PendingJobCounts(TypedDict): + """Counts of pending jobs for each type.""" + + main_jobs: int + fastlane_jobs: int + call_processor_jobs: int + total_jobs: int + + +class FlushStatus(TypedDict): + """Status information about the current flush operation.""" + + # Current job counts + job_counts: PendingJobCounts + + # Tracking of completed jobs + completed_since_last_update: int + total_completed: int + + # Maximum number of jobs seen during this flush operation + max_total_jobs: int + + # Whether there are any pending jobs + has_pending_jobs: bool def get_parallelism_settings() -> tuple[int | None, int | None]: diff --git a/weave/trace_server_bindings/async_batch_processor.py b/weave/trace_server_bindings/async_batch_processor.py index 010e627a4bc9..40ea709b6b50 100644 --- a/weave/trace_server_bindings/async_batch_processor.py +++ b/weave/trace_server_bindings/async_batch_processor.py @@ -50,6 +50,12 @@ def __init__( atexit.register(self.stop_accepting_new_work_and_flush_queue) + @property + def num_outstanding_jobs(self) -> int: + """Returns the number of items currently in the queue.""" + with self.lock: + return self.queue.qsize() + def enqueue(self, items: list[T]) -> None: """ Enqueues a list of items to be processed.