From dbfe8286a8ea242820c058bb620b49f36f905c33 Mon Sep 17 00:00:00 2001 From: Griffin Tarpenning Date: Tue, 11 Feb 2025 10:37:50 -0800 Subject: [PATCH] fix(weave): run sync eval predict functions in parallel (#3652) --- tests/trace/test_evaluate.py | 38 ++++++++++++++++++++++++++++++++++++ weave/trace/op_caller.py | 11 +++++++---- 2 files changed, 45 insertions(+), 4 deletions(-) diff --git a/tests/trace/test_evaluate.py b/tests/trace/test_evaluate.py index 002ed34fee3a..3a51cee41256 100644 --- a/tests/trace/test_evaluate.py +++ b/tests/trace/test_evaluate.py @@ -1,4 +1,5 @@ import asyncio +import time import pytest @@ -286,3 +287,40 @@ def score(self, target, model_output): result = asyncio.run(evaluation.evaluate(model)) assert result["my-scorer"] == {"true_count": 1, "true_fraction": 0.5} + + +def test_sync_eval_parallelism(client): + @weave.op() + def sync_op(a): + time.sleep(1) + return a + + @weave.op() + def score(output): + return 1 + + dataset = [ + {"a": 1}, + {"a": 2}, + {"a": 3}, + {"a": 4}, + {"a": 5}, + {"a": 6}, + {"a": 7}, + {"a": 8}, + {"a": 9}, + {"a": 10}, + ] + + # 10 rows, should complete in <5 seconds. if sync, 10+ + + now = time.time() + + evaluation = Evaluation(dataset=dataset, scorers=[score]) + result = asyncio.run(evaluation.evaluate(sync_op)) + assert result == { + "output": {"mean": 5.5}, + "score": {"mean": 1.0}, + "model_latency": {"mean": pytest.approx(1, abs=1)}, + } + assert time.time() - now < 5 diff --git a/weave/trace/op_caller.py b/weave/trace/op_caller.py index 4903df0768a3..b1ee09787896 100644 --- a/weave/trace/op_caller.py +++ b/weave/trace/op_caller.py @@ -44,7 +44,10 @@ def async_call_op( Returns: A coroutine that will execute the Op and return a tuple of (result, Call) """ - call_res = func.call(*args, __should_raise=True, **kwargs) - if inspect.iscoroutine(call_res): - return call_res - return asyncio.to_thread(lambda: call_res) + is_async = inspect.iscoroutinefunction(func.resolve_fn) + if is_async: + return func.call(*args, __should_raise=True, **kwargs) + else: + return asyncio.to_thread( + lambda: func.call(*args, __should_raise=True, **kwargs) + )