feat(weave): Client & Backend support for Leaderboards (#2831)
* Init

* gen
tssweeney authored Oct 31, 2024
1 parent 3e1647f commit 8521472
Showing 8 changed files with 407 additions and 2 deletions.
191 changes: 191 additions & 0 deletions tests/trace/test_leaderboard.py
@@ -0,0 +1,191 @@
import pytest

import weave
from weave.flow import leaderboard
from weave.trace.weave_client import get_ref


def test_leaderboard_empty(client):
    evaluation_obj_1 = weave.Evaluation(
        name="test_evaluation_name",
        dataset=[{"input": -1, "target": -1}],
        scorers=[],
    )

    weave.publish(evaluation_obj_1)

    spec = leaderboard.Leaderboard(
        name="Empty Leaderboard",
        description="""This is an empty leaderboard""",
        columns=[
            leaderboard.LeaderboardColumn(
                evaluation_object_ref=get_ref(evaluation_obj_1).uri(),
                scorer_name="test_scorer_name",
                summary_metric_path="test_summary_metric_path",
            )
        ],
    )

    ref = weave.publish(spec)

    # Override spec with the fetched copy to show that round-tripping works
    spec = ref.get()

    results = leaderboard.get_leaderboard_results(spec, client)
    assert len(results) == 0


def test_leaderboard_mis_configured(client):
    spec = leaderboard.Leaderboard(
        name="Misconfigured Leaderboard",
        description="""This is a misconfigured leaderboard""",
        columns=[
            leaderboard.LeaderboardColumn(
                evaluation_object_ref="test_evaluation_object_ref",
                scorer_name="test_scorer_name",
                summary_metric_path="test_summary_metric_path",
            )
        ],
    )

    ref = weave.publish(spec)

    # Override spec with the fetched copy to show that round-tripping works
    spec = ref.get()

    results = leaderboard.get_leaderboard_results(spec, client)
    assert len(results) == 0


async def do_evaluations():
    @weave.op
    def my_scorer(target, output):
        return target == output

    evaluation_obj_1 = weave.Evaluation(
        name="test_evaluation_name",
        dataset=[{"input": 1, "target": 1}],
        scorers=[my_scorer],
    )

    @weave.op
    def simple_model(input):
        return input

    await evaluation_obj_1.evaluate(simple_model)

    evaluation_obj_2 = weave.Evaluation(
        name="test_evaluation_name",
        dataset=[{"input": 1, "target": 1}, {"input": 2, "target": 2}],
        scorers=[my_scorer],
    )

    @weave.op
    def static_model(input):
        return 1

    @weave.op
    def bad_model(input):
        return input + 1

    await evaluation_obj_2.evaluate(simple_model)
    await evaluation_obj_2.evaluate(static_model)
    await evaluation_obj_2.evaluate(bad_model)

    return evaluation_obj_1, evaluation_obj_2, simple_model, static_model, bad_model


@pytest.mark.asyncio
async def test_leaderboard_with_results(client):
    (
        evaluation_obj_1,
        evaluation_obj_2,
        simple_model,
        static_model,
        bad_model,
    ) = await do_evaluations()

    spec = leaderboard.Leaderboard(
        name="Simple Leaderboard",
        description="""This is a simple leaderboard""",
        columns=[
            leaderboard.LeaderboardColumn(
                evaluation_object_ref=get_ref(evaluation_obj_1).uri(),
                scorer_name="my_scorer",
                summary_metric_path="true_fraction",
            )
        ],
    )

    ref = weave.publish(spec)

    # Override spec with the fetched copy to show that round-tripping works
    spec = ref.get()

    results = leaderboard.get_leaderboard_results(spec, client)
    assert len(results) == 1
    assert results[0].model_ref == get_ref(simple_model).uri()
    assert results[0].column_scores[0].scores[0].value == 1.0

    spec = leaderboard.Leaderboard(
        name="Complex Leaderboard",
        description="""
        This leaderboard has multiple columns
        ### Columns
        1. Column 1:
           - Evaluation Object: test_evaluation_object_ref
           - Scorer Name: test_scorer_name
           - Summary Metric Path: test_summary_metric_path
        2. Column 2:
           - Evaluation Object: test_evaluation_object_ref
           - Scorer Name: test_scorer_name
           - Summary Metric Path: test_summary_metric_path
        3. Column 3:
           - Evaluation Object: test_evaluation_object_ref
           - Scorer Name: test_scorer_name
           - Summary Metric Path: test_summary_metric_path
        """,
        columns=[
            leaderboard.LeaderboardColumn(
                evaluation_object_ref=get_ref(evaluation_obj_2).uri(),
                scorer_name="my_scorer",
                summary_metric_path="true_count",
            ),
            leaderboard.LeaderboardColumn(
                evaluation_object_ref=get_ref(evaluation_obj_2).uri(),
                scorer_name="my_scorer",
                should_minimize=True,
                summary_metric_path="true_fraction",
            ),
            leaderboard.LeaderboardColumn(
                evaluation_object_ref=get_ref(evaluation_obj_1).uri(),
                scorer_name="my_scorer",
                summary_metric_path="true_fraction",
            ),
        ],
    )

    ref = weave.publish(spec)

    # Override spec with the fetched copy to show that round-tripping works
    spec = ref.get()

    results = leaderboard.get_leaderboard_results(spec, client)
    assert len(results) == 3
    assert results[0].model_ref == get_ref(simple_model).uri()
    assert len(results[0].column_scores) == 3
    assert results[0].column_scores[0].scores[0].value == 2.0
    assert results[0].column_scores[1].scores[0].value == 1.0
    assert results[0].column_scores[2].scores[0].value == 1.0
    assert results[1].model_ref == get_ref(static_model).uri()
    assert len(results[1].column_scores) == 3
    assert results[1].column_scores[0].scores[0].value == 1.0
    assert results[1].column_scores[1].scores[0].value == 0.5
    assert len(results[1].column_scores[2].scores) == 0
    assert results[2].model_ref == get_ref(bad_model).uri()
    assert len(results[2].column_scores) == 3
    assert results[2].column_scores[0].scores[0].value == 0
    assert results[2].column_scores[1].scores[0].value == 0
    assert len(results[2].column_scores[2].scores) == 0
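These assertions lean on the shape of an evaluation summary: `my_scorer` returns booleans, which the summary rolls up into metrics such as `true_count` and `true_fraction`, and `summary_metric_path` indexes into that structure with dot-separated parts. A minimal sketch of the traversal (the `summary` dict below is an assumed shape inferred from the paths used in these tests, not output captured from Weave; the logic mirrors `get_leaderboard_results` in `weave/flow/leaderboard.py` later in this diff):

    # Assumed summary shape for a boolean scorer named "my_scorer".
    summary = {"my_scorer": {"true_count": 2, "true_fraction": 1.0}}

    def select_metric(summary: dict, scorer_name: str, metric_path: str):
        # Walk the dot-separated path, indexing dicts by key and lists by position.
        val = summary.get(scorer_name)
        for part in metric_path.split("."):
            if isinstance(val, dict):
                val = val.get(part)
            elif isinstance(val, list):
                val = val[int(part)]
            else:
                break
        return val

    assert select_metric(summary, "my_scorer", "true_count") == 2
    assert select_metric(summary, "my_scorer", "true_fraction") == 1.0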
@@ -1,5 +1,13 @@
import * as z from 'zod';

export const LeaderboardColumnSchema = z.object({
  evaluation_object_ref: z.string(),
  scorer_name: z.string(),
  should_minimize: z.union([z.boolean(), z.null()]).optional(),
  summary_metric_path: z.string(),
});
export type LeaderboardColumn = z.infer<typeof LeaderboardColumnSchema>;

export const TestOnlyNestedBaseModelSchema = z.object({
  a: z.number(),
});
@@ -16,6 +24,13 @@ export type TestOnlyNestedBaseObject = z.infer<
  typeof TestOnlyNestedBaseObjectSchema
>;

export const LeaderboardSchema = z.object({
  columns: z.array(LeaderboardColumnSchema),
  description: z.union([z.null(), z.string()]).optional(),
  name: z.union([z.null(), z.string()]).optional(),
});
export type Leaderboard = z.infer<typeof LeaderboardSchema>;

export const TestOnlyExampleSchema = z.object({
  description: z.union([z.null(), z.string()]).optional(),
  name: z.union([z.null(), z.string()]).optional(),
@@ -26,6 +41,7 @@ export const TestOnlyExampleSchema = z.object({
export type TestOnlyExample = z.infer<typeof TestOnlyExampleSchema>;

export const baseObjectClassRegistry = {
  Leaderboard: LeaderboardSchema,
  TestOnlyExample: TestOnlyExampleSchema,
  TestOnlyNestedBaseObject: TestOnlyNestedBaseObjectSchema,
};
87 changes: 87 additions & 0 deletions weave/flow/leaderboard.py
@@ -0,0 +1,87 @@
from typing import Any

from pydantic import BaseModel

from weave.trace.refs import OpRef
from weave.trace.weave_client import WeaveClient, get_ref
from weave.trace_server.interface.base_object_classes import leaderboard
from weave.trace_server.trace_server_interface import CallsFilter


class LeaderboardModelEvaluationResult(BaseModel):
    evaluate_call_ref: str
    value: Any


class ModelScoresForColumn(BaseModel):
    scores: list[LeaderboardModelEvaluationResult]


class LeaderboardModelResult(BaseModel):
    model_ref: str
    column_scores: list[ModelScoresForColumn]


def get_leaderboard_results(
    spec: leaderboard.Leaderboard, client: WeaveClient
) -> list[LeaderboardModelResult]:
    entity, project = client._project_id().split("/")
    calls = client.get_calls(
        filter=CallsFilter(
            op_names=[
                OpRef(
                    entity=entity,
                    project=project,
                    name="Evaluation.evaluate",
                    _digest="*",
                ).uri()
            ],
            input_refs=[c.evaluation_object_ref for c in spec.columns],
        )
    )

    res_map: dict[str, LeaderboardModelResult] = {}
    for call in calls:
        # Frustrating that we have to get the ref like this. Since the
        # `Call` object auto-derefs the inputs (making a network request),
        # we have to manually get the ref here... a waste of network calls.
        call_ref = get_ref(call)
        if call_ref is None:
            continue
        call_ref_uri = call_ref.uri()

        model_ref = get_ref(call.inputs["model"])
        if model_ref is None:
            continue
        model_ref_uri = model_ref.uri()
        if model_ref_uri not in res_map:
            res_map[model_ref_uri] = LeaderboardModelResult(
                model_ref=model_ref_uri,
                column_scores=[ModelScoresForColumn(scores=[]) for _ in spec.columns],
            )
        for col_idx, c in enumerate(spec.columns):
            eval_obj_ref = get_ref(call.inputs["self"])
            if eval_obj_ref is None:
                continue
            eval_obj_ref_uri = eval_obj_ref.uri()
            if c.evaluation_object_ref != eval_obj_ref_uri:
                continue
            val = call.output.get(c.scorer_name)
            for part in c.summary_metric_path.split("."):
                if isinstance(val, dict):
                    val = val.get(part)
                elif isinstance(val, list):
                    val = val[int(part)]
                else:
                    break
            res_map[model_ref_uri].column_scores[col_idx].scores.append(
                LeaderboardModelEvaluationResult(
                    evaluate_call_ref=call_ref_uri, value=val
                )
            )
    return list(res_map.values())


# Re-export:
Leaderboard = leaderboard.Leaderboard
LeaderboardColumn = leaderboard.LeaderboardColumn
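Putting the new API together, the intended flow is: publish an Evaluation, run models through it, publish a Leaderboard whose columns point at the evaluation ref, then pull per-model results. A minimal usage sketch based on the tests above (the project name is a placeholder, not part of this commit):

    import weave
    from weave.flow import leaderboard
    from weave.trace.weave_client import get_ref

    client = weave.init("my-entity/my-project")  # placeholder project name

    @weave.op
    def my_scorer(target, output):
        return target == output

    my_evaluation = weave.Evaluation(
        name="my_evaluation",
        dataset=[{"input": 1, "target": 1}],
        scorers=[my_scorer],
    )
    weave.publish(my_evaluation)

    # (Run models through the evaluation first, e.g.
    # `await my_evaluation.evaluate(model)`, so the leaderboard has rows.)

    spec = leaderboard.Leaderboard(
        name="My Leaderboard",
        columns=[
            leaderboard.LeaderboardColumn(
                evaluation_object_ref=get_ref(my_evaluation).uri(),
                scorer_name="my_scorer",
                summary_metric_path="true_fraction",
            )
        ],
    )
    ref = weave.publish(spec)

    # Fetch the spec back through its ref and compute per-model results.
    for result in leaderboard.get_leaderboard_results(ref.get(), client):
        print(result.model_ref, [c.scores for c in result.column_scores])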
7 changes: 7 additions & 0 deletions weave/trace/api.py
@@ -22,6 +22,7 @@
    should_disable_weave,
)
from weave.trace.table import Table
from weave.trace_server.interface.base_object_classes import leaderboard


def init(
@@ -109,6 +110,12 @@ def publish(obj: Any, name: Optional[str] = None) -> weave_client.ObjectRef:
            ref.name,
            ref.digest,
        )
    elif isinstance(obj, leaderboard.Leaderboard):
        url = urls.leaderboard_path(
            ref.entity,
            ref.project,
            ref.name,
        )
    # TODO(gst): once frontend has direct dataset/model links
    # elif isinstance(obj, weave_client.Dataset):
    else:
4 changes: 4 additions & 0 deletions weave/trace/urls.py
@@ -34,5 +34,9 @@ def object_version_path(
    return f"{project_weave_root_url(entity_name, project_name)}/objects/{quote(object_name)}/versions/{obj_version}"


def leaderboard_path(entity_name: str, project_name: str, object_name: str) -> str:
    return f"{project_weave_root_url(entity_name, project_name)}/leaderboards/{quote(object_name)}"


def redirect_call(entity_name: str, project_name: str, call_id: str) -> str:
    return f"{remote_project_root_url(entity_name, project_name)}/r/call/{call_id}"
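For orientation, a sketch of what the new helper yields (entity, project, and object names are placeholders; the exact host and root segment come from whatever project_weave_root_url returns):

    from weave.trace import urls

    # Produces something like
    #   <weave project root>/leaderboards/My%20Leaderboard
    # with the object name URL-quoted by the helper.
    print(urls.leaderboard_path("my-entity", "my-project", "My Leaderboard"))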
@@ -1,7 +1,11 @@
from typing import Dict, Type

from weave.trace_server.interface.base_object_classes.base_object_def import BaseObject
from weave.trace_server.interface.base_object_classes.test_only_example import *
from weave.trace_server.interface.base_object_classes.leaderboard import Leaderboard
from weave.trace_server.interface.base_object_classes.test_only_example import (
    TestOnlyExample,
    TestOnlyNestedBaseObject,
)

BASE_OBJECT_REGISTRY: Dict[str, Type[BaseObject]] = {}

@@ -18,3 +22,4 @@ def register_base_object(cls: Type[BaseObject]) -> None:

register_base_object(TestOnlyExample)
register_base_object(TestOnlyNestedBaseObject)
register_base_object(Leaderboard)
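With Leaderboard in BASE_OBJECT_REGISTRY, server-side code can resolve the class by name and validate payloads against it. A minimal sketch, assuming the registry module path shown below and that register_base_object keys entries by class name (both are assumptions, as the diff does not show them):

    # Assumed module path for the registry file in this diff.
    from weave.trace_server.interface.base_object_classes.base_object_registry import (
        BASE_OBJECT_REGISTRY,
    )

    # Look up the registered class by name (assumes name-keyed registration)
    # and validate a minimal payload against the Leaderboard schema.
    leaderboard_cls = BASE_OBJECT_REGISTRY["Leaderboard"]
    spec = leaderboard_cls.model_validate(
        {"name": "lb", "description": None, "columns": []}
    )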