feat(weave): Client & Backend support for Leaderboards (#2831)
* Init
* gen
Showing 8 changed files with 407 additions and 2 deletions.
@@ -0,0 +1,191 @@
import pytest

import weave
from weave.flow import leaderboard
from weave.trace.weave_client import get_ref


def test_leaderboard_empty(client):
    evaluation_obj_1 = weave.Evaluation(
        name="test_evaluation_name",
        dataset=[{"input": -1, "target": -1}],
        scorers=[],
    )

    weave.publish(evaluation_obj_1)

    spec = leaderboard.Leaderboard(
        name="Empty Leaderboard",
        description="""This is an empty leaderboard""",
        columns=[
            leaderboard.LeaderboardColumn(
                evaluation_object_ref=get_ref(evaluation_obj_1).uri(),
                scorer_name="test_scorer_name",
                summary_metric_path="test_summary_metric_path",
            )
        ],
    )

    ref = weave.publish(spec)

    # Overriding spec to show that this works
    spec = ref.get()

    results = leaderboard.get_leaderboard_results(spec, client)
    assert len(results) == 0


def test_leaderboard_mis_configured(client):
    spec = leaderboard.Leaderboard(
        name="Misconfigured Leaderboard",
        description="""This is a misconfigured leaderboard""",
        columns=[
            leaderboard.LeaderboardColumn(
                evaluation_object_ref="test_evaluation_object_ref",
                scorer_name="test_scorer_name",
                summary_metric_path="test_summary_metric_path",
            )
        ],
    )

    ref = weave.publish(spec)

    # Overriding spec to show that this works
    spec = ref.get()

    results = leaderboard.get_leaderboard_results(spec, client)
    assert len(results) == 0


async def do_evaluations():
    @weave.op
    def my_scorer(target, output):
        return target == output

    evaluation_obj_1 = weave.Evaluation(
        name="test_evaluation_name",
        dataset=[{"input": 1, "target": 1}],
        scorers=[my_scorer],
    )

    @weave.op
    def simple_model(input):
        return input

    await evaluation_obj_1.evaluate(simple_model)

    evaluation_obj_2 = weave.Evaluation(
        name="test_evaluation_name",
        dataset=[{"input": 1, "target": 1}, {"input": 2, "target": 2}],
        scorers=[my_scorer],
    )

    @weave.op
    def static_model(input):
        return 1

    @weave.op
    def bad_model(input):
        return input + 1

    await evaluation_obj_2.evaluate(simple_model)
    await evaluation_obj_2.evaluate(static_model)
    await evaluation_obj_2.evaluate(bad_model)

    return evaluation_obj_1, evaluation_obj_2, simple_model, static_model, bad_model


@pytest.mark.asyncio
async def test_leaderboard_with_results(client):
    (
        evaluation_obj_1,
        evaluation_obj_2,
        simple_model,
        static_model,
        bad_model,
    ) = await do_evaluations()

    spec = leaderboard.Leaderboard(
        name="Simple Leaderboard",
        description="""This is a simple leaderboard""",
        columns=[
            leaderboard.LeaderboardColumn(
                evaluation_object_ref=get_ref(evaluation_obj_1).uri(),
                scorer_name="my_scorer",
                summary_metric_path="true_fraction",
            )
        ],
    )

    ref = weave.publish(spec)

    # Overriding spec to show that this works
    spec = ref.get()

    results = leaderboard.get_leaderboard_results(spec, client)
    assert len(results) == 1
    assert results[0].model_ref == get_ref(simple_model).uri()
    assert results[0].column_scores[0].scores[0].value == 1.0

    spec = leaderboard.Leaderboard(
        name="Complex Leaderboard",
        description="""
This leaderboard has multiple columns
### Columns
1. Column 1:
  - Evaluation Object: test_evaluation_object_ref
  - Scorer Name: test_scorer_name
  - Summary Metric Path: test_summary_metric_path
2. Column 2:
  - Evaluation Object: test_evaluation_object_ref
  - Scorer Name: test_scorer_name
  - Summary Metric Path: test_summary_metric_path
3. Column 3:
  - Evaluation Object: test_evaluation_object_ref
  - Scorer Name: test_scorer_name
  - Summary Metric Path: test_summary_metric_path
""",
        columns=[
            leaderboard.LeaderboardColumn(
                evaluation_object_ref=get_ref(evaluation_obj_2).uri(),
                scorer_name="my_scorer",
                summary_metric_path="true_count",
            ),
            leaderboard.LeaderboardColumn(
                evaluation_object_ref=get_ref(evaluation_obj_2).uri(),
                scorer_name="my_scorer",
                should_minimize=True,
                summary_metric_path="true_fraction",
            ),
            leaderboard.LeaderboardColumn(
                evaluation_object_ref=get_ref(evaluation_obj_1).uri(),
                scorer_name="my_scorer",
                summary_metric_path="true_fraction",
            ),
        ],
    )

    ref = weave.publish(spec)

    # Overriding spec to show that this works
    spec = ref.get()

    results = leaderboard.get_leaderboard_results(spec, client)
    assert len(results) == 3
    assert results[0].model_ref == get_ref(simple_model).uri()
    assert len(results[0].column_scores) == 3
    assert results[0].column_scores[0].scores[0].value == 2.0
    assert results[0].column_scores[1].scores[0].value == 1.0
    assert results[0].column_scores[2].scores[0].value == 1.0
    assert results[1].model_ref == get_ref(static_model).uri()
    assert len(results[1].column_scores) == 3
    assert results[1].column_scores[0].scores[0].value == 1.0
    assert results[1].column_scores[1].scores[0].value == 0.5
    assert len(results[1].column_scores[2].scores) == 0
    assert results[2].model_ref == get_ref(bad_model).uri()
    assert len(results[2].column_scores) == 3
    assert results[2].column_scores[0].scores[0].value == 0
    assert results[2].column_scores[1].scores[0].value == 0
    assert len(results[2].column_scores[2].scores) == 0
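
A note on the `true_count` / `true_fraction` paths above: `Evaluation.evaluate` rolls boolean scorer outputs up into a per-scorer summary dict, and a column's `summary_metric_path` is a dotted path into that summary, with dict segments looked up by key and list segments by integer index. A small sketch of the traversal, assuming the summary shape the assertions above imply:

# Assumed shape of one evaluate call's output for `my_scorer` on
# evaluation_obj_2 with simple_model (two correct predictions).
output = {"my_scorer": {"true_count": 2, "true_fraction": 1.0}}

# Same walk as get_leaderboard_results: start at the scorer's entry,
# then follow each dotted segment.
val = output.get("my_scorer")
for part in "true_fraction".split("."):
    if isinstance(val, dict):
        val = val.get(part)
    elif isinstance(val, list):
        val = val[int(part)]
    else:
        break

assert val == 1.0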
@@ -0,0 +1,87 @@
from typing import Any

from pydantic import BaseModel

from weave.trace.refs import OpRef
from weave.trace.weave_client import WeaveClient, get_ref
from weave.trace_server.interface.base_object_classes import leaderboard
from weave.trace_server.trace_server_interface import CallsFilter


class LeaderboardModelEvaluationResult(BaseModel):
    evaluate_call_ref: str
    value: Any


class ModelScoresForColumn(BaseModel):
    scores: list[LeaderboardModelEvaluationResult]


class LeaderboardModelResult(BaseModel):
    model_ref: str
    column_scores: list[ModelScoresForColumn]


def get_leaderboard_results(
    spec: leaderboard.Leaderboard, client: WeaveClient
) -> list[LeaderboardModelResult]:
    entity, project = client._project_id().split("/")
    calls = client.get_calls(
        filter=CallsFilter(
            op_names=[
                OpRef(
                    entity=entity,
                    project=project,
                    name="Evaluation.evaluate",
                    _digest="*",
                ).uri()
            ],
            input_refs=[c.evaluation_object_ref for c in spec.columns],
        )
    )

    res_map: dict[str, LeaderboardModelResult] = {}
    for call in calls:
        # Frustrating that we have to get the ref like this. Since the
        # `Call` object auto-derefs the inputs (making a network request),
        # we have to manually get the ref here... waste of network calls.
        call_ref = get_ref(call)
        if call_ref is None:
            continue
        call_ref_uri = call_ref.uri()

        model_ref = get_ref(call.inputs["model"])
        if model_ref is None:
            continue
        model_ref_uri = model_ref.uri()
        if model_ref_uri not in res_map:
            res_map[model_ref_uri] = LeaderboardModelResult(
                model_ref=model_ref_uri,
                column_scores=[ModelScoresForColumn(scores=[]) for _ in spec.columns],
            )
        for col_idx, c in enumerate(spec.columns):
            eval_obj_ref = get_ref(call.inputs["self"])
            if eval_obj_ref is None:
                continue
            eval_obj_ref_uri = eval_obj_ref.uri()
            if c.evaluation_object_ref != eval_obj_ref_uri:
                continue
            val = call.output.get(c.scorer_name)
            for part in c.summary_metric_path.split("."):
                if isinstance(val, dict):
                    val = val.get(part)
                elif isinstance(val, list):
                    val = val[int(part)]
                else:
                    break
            res_map[model_ref_uri].column_scores[col_idx].scores.append(
                LeaderboardModelEvaluationResult(
                    evaluate_call_ref=call_ref_uri, value=val
                )
            )
    return list(res_map.values())


# Re-export:
Leaderboard = leaderboard.Leaderboard
LeaderboardColumn = leaderboard.LeaderboardColumn
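
Consuming the results is a plain walk over the returned pydantic models; `column_scores` is parallel to `spec.columns`. A short sketch, assuming `spec` and `client` were obtained as in the tests above:

for row in get_leaderboard_results(spec, client):
    print(row.model_ref)
    for col, bucket in zip(spec.columns, row.column_scores):
        for score in bucket.scores:
            # One entry per Evaluation.evaluate call that matched this column.
            print(f"  {col.scorer_name}/{col.summary_metric_path} = {score.value}")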