Skip to content

Commit

Permalink
chore(weave): add rich progress bar to client flush (#3828)
Browse files Browse the repository at this point in the history
  • Loading branch information
gtarpenning authored Mar 4, 2025
1 parent ed52f62 commit 3cc7f6f
Show file tree
Hide file tree
Showing 5 changed files with 356 additions and 15 deletions.
43 changes: 43 additions & 0 deletions tests/trace/test_weave_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
import platform
import sys
import time

import pydantic
import pytest
Expand Down Expand Up @@ -1886,3 +1887,45 @@ def my_op(a: int) -> int:

# Local attributes override global ones
assert call.attributes["env"] == "override"


def test_flush_progress_bar(client):
    """Flush with the progress bar enabled and verify no jobs remain pending."""

    @weave.op
    def op_1():
        time.sleep(1)

    op_1()

    # flush with progress bar
    client.flush(use_progress_bar=True)

    # make sure there are no pending jobs
    assert client._get_pending_jobs()["total_jobs"] == 0
    assert not client._has_pending_jobs()


def test_flush_callback(client):
    """Flush with a user callback, alone and combined with ``use_progress_bar``.

    After each flush the client must report zero pending jobs.
    """

    @weave.op
    def op_1():
        time.sleep(1)

    def fake_logger(status):
        # Every status handed to the callback must carry the job counts.
        assert "job_counts" in status

    def assert_no_pending_jobs():
        assert client._get_pending_jobs()["total_jobs"] == 0
        assert not client._has_pending_jobs()

    op_1()

    # flush with callback
    client.flush(callback=fake_logger)

    # make sure there are no pending jobs
    assert_no_pending_jobs()

    op_1()

    # this should also work, the callback will override the progress bar
    client.flush(callback=fake_logger, use_progress_bar=True)

    # make sure there are no pending jobs
    assert_no_pending_jobs()
108 changes: 108 additions & 0 deletions weave/trace/client_progress_bar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
"""Progress bar utilities for WeaveClient.
This module provides functionality for displaying progress bars when flushing
tasks in the WeaveClient.
"""

from typing import Callable

from rich.console import Console
from rich.progress import (
BarColumn,
Progress,
SpinnerColumn,
TaskProgressColumn,
TextColumn,
TimeElapsedColumn,
)

from weave.trace.weave_client import FlushStatus


def create_progress_bar_callback() -> Callable[[FlushStatus], None]:
    """Create a callback function that displays a progress bar for flush status.

    The progress display is started immediately. The single tracked task is
    created lazily, on the first status update that reports at least one job.

    Returns:
        A callback function that can be passed to WeaveClient._flush.
    """
    console = Console()

    # Create a progress bar instance
    progress = Progress(
        SpinnerColumn(),
        TextColumn("[bold blue]{task.description}"),
        BarColumn(bar_width=None, complete_style="magenta"),
        TaskProgressColumn(),
        TimeElapsedColumn(),
        console=console,
        refresh_per_second=10,
        expand=True,
        transient=False,
    )

    # Start the progress display
    progress.start()

    # Task handle; stays None until the first update that reports jobs.
    task_id = None
    # Total last given to the bar. Tracked here so we never have to index
    # ``progress.tasks`` (a list) by task ID.
    known_total = 0

    def progress_callback(status: FlushStatus) -> None:
        """Update the progress bar based on the flush status.

        Args:
            status: The current flush status.
        """
        nonlocal task_id, known_total

        counts = status["job_counts"]

        # If this is the first update, create the task
        if task_id is None:
            if counts["total_jobs"] == 0:
                # No jobs to track, just return
                progress.stop()
                return

            # The display is only down at this point if an earlier empty
            # update stopped it. Announce the work and bring the display back
            # up, otherwise the task added below would never be rendered.
            if not progress.live.is_started:
                print(f"Flushing {counts['total_jobs']} pending tasks...")
                progress.start()

            # Create the task
            known_total = counts["total_jobs"]
            task_id = progress.add_task("Flushing tasks", total=known_total)

        # If there are no more pending jobs, complete the progress bar
        if not status["has_pending_jobs"]:
            known_total = status["max_total_jobs"]
            progress.update(
                task_id,
                completed=known_total,
                total=known_total,
            )
            progress.stop()
            return

        # If new jobs were added, update the total
        if status["max_total_jobs"] > known_total:
            known_total = status["max_total_jobs"]
            progress.update(task_id, total=known_total)

        # Update progress bar with completed jobs
        if status["completed_since_last_update"] > 0:
            progress.update(task_id, advance=status["completed_since_last_update"])

        # Format job details for description; only non-empty queues are listed
        labelled_counts = (
            (counts["main_jobs"], "main"),
            (counts["fastlane_jobs"], "file-upload"),
            (counts["call_processor_jobs"], "call-batch"),
        )
        job_details = [f"{count} {label}" for count, label in labelled_counts if count > 0]
        job_details_str = ", ".join(job_details) if job_details else "none"

        # Update progress bar description
        progress.update(
            task_id,
            description=f"Flushing tasks ({counts['total_jobs']} remaining: {job_details_str})",
        )

    return progress_callback
5 changes: 5 additions & 0 deletions weave/trace/concurrent/futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ def __init__(
self._in_thread_context = ContextVar("in_deferred_context", default=False)
atexit.register(self._shutdown)

@property
def num_outstanding_futures(self) -> int:
    """Return the number of entries currently in ``_active_futures``.

    The count is read while holding ``_active_futures_lock``.
    """
    with self._active_futures_lock:
        outstanding = len(self._active_futures)
    return outstanding

def defer(self, f: Callable[..., T], *args: Any, **kwargs: Any) -> Future[T]:
"""
Defer a function to be executed in a thread pool.
Expand Down
Loading

0 comments on commit 3cc7f6f

Please sign in to comment.