Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(weave): add rich progress bar to client flush #3828

Merged
merged 11 commits into from
Mar 4, 2025
43 changes: 43 additions & 0 deletions tests/trace/test_weave_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
import platform
import sys
import time

import pydantic
import pytest
Expand Down Expand Up @@ -1886,3 +1887,45 @@ def my_op(a: int) -> int:

# Local attributes override global ones
assert call.attributes["env"] == "override"


def test_flush_progress_bar(client):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is not testing the pbar.

If you really want to test the pbar, you can capture redirect IO into a buf and confirm it looks the way you want

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hmmm yeah, the issue is actually more generating a test that actually requires flushing. not sure there is a trivial way of doing it... scratching my head over here

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are a few ways, but maybe the easiest is to use a mock server that is slow to respond / only accepts 1 connection. Then you can queue up some jobs and see them go through 1-by-1

@weave.op
def op_1():
time.sleep(1)

op_1()

# flush with progress bar
client.flush(use_progress_bar=True)

# make sure there are no pending jobs
assert client._get_pending_jobs()["total_jobs"] == 0
assert client._has_pending_jobs() == False


def test_flush_callback(client):
@weave.op
def op_1():
time.sleep(1)

op_1()

def fake_logger(status):
assert "job_counts" in status

# flush with callback
client.flush(callback=fake_logger)

# make sure there are no pending jobs
assert client._get_pending_jobs()["total_jobs"] == 0
assert client._has_pending_jobs() == False

op_1()

# this should also work, the callback will override the progress bar
client.flush(callback=fake_logger, use_progress_bar=True)

# make sure there are no pending jobs
assert client._get_pending_jobs()["total_jobs"] == 0
assert client._has_pending_jobs() == False
108 changes: 108 additions & 0 deletions weave/trace/client_progress_bar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
"""Progress bar utilities for WeaveClient.

This module provides functionality for displaying progress bars when flushing
tasks in the WeaveClient.
"""

from typing import Callable

from rich.console import Console
from rich.progress import (
BarColumn,
Progress,
SpinnerColumn,
TaskProgressColumn,
TextColumn,
TimeElapsedColumn,
)

from weave.trace.weave_client import FlushStatus


def create_progress_bar_callback() -> Callable[[FlushStatus], None]:
"""Create a callback function that displays a progress bar for flush status.

Returns:
A callback function that can be passed to WeaveClient._flush.
"""
console = Console()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you check that it doesn't overwrite user logs? I'm not sure why but rich sometimes paints over my terminal which can be annoying

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think its fine? How exactly can I test that?
Screenshot 2025-03-03 at 7 03 43 PM

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The trace donut link gets printed while the bar is going, I think its fine.


# Create a progress bar instance
progress = Progress(
SpinnerColumn(),
TextColumn("[bold blue]{task.description}"),
BarColumn(bar_width=None, complete_style="magenta"),
TaskProgressColumn(),
TimeElapsedColumn(),
console=console,
refresh_per_second=10,
expand=True,
transient=False,
)

# Start the progress display
progress.start()

# Create a task for tracking progress
task_id = None

def progress_callback(status: FlushStatus) -> None:
"""Update the progress bar based on the flush status.

Args:
status: The current flush status.
"""
nonlocal task_id

counts = status["job_counts"]

# If this is the first update, create the task
if task_id is None:
if counts["total_jobs"] == 0:
# No jobs to track, just return
progress.stop()
return

# Print initial message
if not progress.live.is_started:
print(f"Flushing {counts['total_jobs']} pending tasks...")

# Create the task
task_id = progress.add_task("Flushing tasks", total=counts["total_jobs"])

# If there are no more pending jobs, complete the progress bar
if not status["has_pending_jobs"]:
progress.update(
task_id,
completed=status["max_total_jobs"],
total=status["max_total_jobs"],
)
progress.stop()
return

# If new jobs were added, update the total
if status["max_total_jobs"] > progress.tasks[task_id].total:
progress.update(task_id, total=status["max_total_jobs"])

# Update progress bar with completed jobs
if status["completed_since_last_update"] > 0:
progress.update(task_id, advance=status["completed_since_last_update"])

# Format job details for description
job_details = []
if counts["main_jobs"] > 0:
job_details.append(f"{counts['main_jobs']} main")
if counts["fastlane_jobs"] > 0:
job_details.append(f"{counts['fastlane_jobs']} file-upload")
if counts["call_processor_jobs"] > 0:
job_details.append(f"{counts['call_processor_jobs']} call-batch")

job_details_str = ", ".join(job_details) if job_details else "none"

# Update progress bar description
progress.update(
task_id,
description=f"Flushing tasks ({counts['total_jobs']} remaining: {job_details_str})",
)

return progress_callback
5 changes: 5 additions & 0 deletions weave/trace/concurrent/futures.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ def __init__(
self._in_thread_context = ContextVar("in_deferred_context", default=False)
atexit.register(self._shutdown)

@property
def num_outstanding_futures(self) -> int:
with self._active_futures_lock:
return len(self._active_futures)

def defer(self, f: Callable[..., T], *args: Any, **kwargs: Any) -> Future[T]:
"""
Defer a function to be executed in a thread pool.
Expand Down
Loading