diff --git a/tests/trace/test_dataset.py b/tests/trace/test_dataset.py
index 3054bef1a4c0..02ab2f0e48a9 100644
--- a/tests/trace/test_dataset.py
+++ b/tests/trace/test_dataset.py
@@ -1,6 +1,7 @@
 import pytest
 
 import weave
+from weave.trace.context.tests_context import raise_on_captured_errors
 
 
 def test_basic_dataset_lifecycle(client):
@@ -54,3 +55,13 @@ def greet(name: str, age: int) -> str:
     assert rows[1]["inputs"]["name"] == "Bob"
     assert rows[1]["inputs"]["age"] == 25
     assert rows[1]["output"] == "Hello Bob, you are 25!"
+
+
+def test_dataset_caching(client):
+    ds = weave.Dataset(rows=[{"a": i} for i in range(200)])
+    ref = weave.publish(ds)
+
+    ds2 = ref.get()
+
+    with raise_on_captured_errors():
+        assert len(ds2) == 200
diff --git a/weave/trace/vals.py b/weave/trace/vals.py
index 58f1515c370f..21014eb2deb4 100644
--- a/weave/trace/vals.py
+++ b/weave/trace/vals.py
@@ -397,15 +397,27 @@ def _remote_iter(self) -> Generator[dict, None, None]:
             )
         )
 
-        if self._prefetched_rows is not None and len(response.rows) != len(
-            self._prefetched_rows
-        ):
-            if get_raise_on_captured_errors():
-                raise
-            logger.error(
-                f"Expected length of response rows ({len(response.rows)}) to match prefetched rows ({len(self._prefetched_rows)}). Ignoring prefetched rows."
-            )
-            self._prefetched_rows = None
+        # When paginating through large datasets, we need special handling for prefetched rows
+        # on the first page. This is because prefetched_rows contains ALL rows, while each
+        # response page contains at most page_size rows.
+        if page_index == 0 and self._prefetched_rows is not None:
+            response_rows_len = len(response.rows)
+            prefetched_rows_len = len(self._prefetched_rows)
+
+            # There are two valid scenarios:
+            # 1. The response rows exactly match prefetched rows (small dataset, no pagination needed)
+            # 2. We're paginating a large dataset (response has page_size rows, prefetched has more)
+            #
+            # Any other mismatch indicates an inconsistency that should be handled by
+            # discarding the prefetched rows and relying solely on server responses.
+            if response_rows_len != prefetched_rows_len and not (
+                response_rows_len == page_size and prefetched_rows_len > page_size
+            ):
+                msg = f"Expected length of response rows ({response_rows_len}) to match prefetched rows ({prefetched_rows_len}). Ignoring prefetched rows."
+                if get_raise_on_captured_errors():
+                    raise ValueError(msg)
+                logger.debug(msg)
+                self._prefetched_rows = None
 
         for i, item in enumerate(response.rows):
             new_ref = self.ref.with_item(item.digest) if self.ref else None
@@ -418,7 +430,7 @@ def _remote_iter(self) -> Generator[dict, None, None]:
             val = (
                 item.val
                 if self._prefetched_rows is None
-                else self._prefetched_rows[i]
+                else self._prefetched_rows[page_index * page_size + i]
             )
             res = from_json(val, self.table_ref.project_id, self.server)
             res = make_trace_obj(res, new_ref, self.server, self.root)
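
The heart of the vals.py change is the offset arithmetic page_index * page_size + i, which maps a row's position within the current response page back to its position in the flat prefetched list; the first-page check then guards that mapping. The stand-alone sketch below demonstrates both pieces in isolation. Note that iter_rows, fetch_page, and the page size of 2 are hypothetical stand-ins for illustration, not Weave API; the real page size is defined in the surrounding _remote_iter code, which this diff does not show.

# Stand-alone sketch of the pagination/prefetch interplay in the diff above.
def iter_rows(prefetched_rows, fetch_page, page_size=2):
    page_index = 0
    while True:
        page = fetch_page(page_index * page_size, page_size)  # simulated server call
        if not page:
            return
        # First-page check: a length mismatch is acceptable only when the server
        # is paginating (a full page returned, more rows prefetched than one page).
        if page_index == 0 and prefetched_rows is not None:
            if len(page) != len(prefetched_rows) and not (
                len(page) == page_size and len(prefetched_rows) > page_size
            ):
                prefetched_rows = None  # inconsistent: fall back to server rows
        for i, server_val in enumerate(page):
            # page_index * page_size + i converts the per-page index i back
            # into an index over the flat prefetched list.
            if prefetched_rows is None:
                yield server_val
            else:
                yield prefetched_rows[page_index * page_size + i]
        page_index += 1


rows = [{"a": i} for i in range(5)]
server = lambda offset, limit: rows[offset : offset + limit]
assert list(iter_rows(rows, server)) == rows  # 3 pages: 2 + 2 + 1 rows

With the 200-row dataset in the new test, any server page size below 200 exercises the second valid scenario (a full first page plus a longer prefetched list), while a larger page size exercises the first (exact match); either way the old code would have logged an error and discarded the prefetched rows on every paginated read.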