Skip to content

Commit

Permalink
fix(weave): run sync eval predict functions in parallel (#3652)
Browse files Browse the repository at this point in the history
  • Loading branch information
gtarpenning authored Feb 11, 2025
1 parent bc38e39 commit dbfe828
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 4 deletions.
38 changes: 38 additions & 0 deletions tests/trace/test_evaluate.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import time

import pytest

Expand Down Expand Up @@ -286,3 +287,40 @@ def score(self, target, model_output):

result = asyncio.run(evaluation.evaluate(model))
assert result["my-scorer"] == {"true_count": 1, "true_fraction": 0.5}


def test_sync_eval_parallelism(client):
@weave.op()
def sync_op(a):
time.sleep(1)
return a

@weave.op()
def score(output):
return 1

dataset = [
{"a": 1},
{"a": 2},
{"a": 3},
{"a": 4},
{"a": 5},
{"a": 6},
{"a": 7},
{"a": 8},
{"a": 9},
{"a": 10},
]

# 10 rows, should complete in <5 seconds. if sync, 10+

now = time.time()

evaluation = Evaluation(dataset=dataset, scorers=[score])
result = asyncio.run(evaluation.evaluate(sync_op))
assert result == {
"output": {"mean": 5.5},
"score": {"mean": 1.0},
"model_latency": {"mean": pytest.approx(1, abs=1)},
}
assert time.time() - now < 5
11 changes: 7 additions & 4 deletions weave/trace/op_caller.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@ def async_call_op(
Returns:
A coroutine that will execute the Op and return a tuple of (result, Call)
"""
call_res = func.call(*args, __should_raise=True, **kwargs)
if inspect.iscoroutine(call_res):
return call_res
return asyncio.to_thread(lambda: call_res)
is_async = inspect.iscoroutinefunction(func.resolve_fn)
if is_async:
return func.call(*args, __should_raise=True, **kwargs)
else:
return asyncio.to_thread(
lambda: func.call(*args, __should_raise=True, **kwargs)
)

0 comments on commit dbfe828

Please sign in to comment.