fix(weave): Fix table ref get caching issue (#3764)
andrewtruong authored Feb 28, 2025
1 parent 2caadfe commit 1d702d4
Showing 2 changed files with 33 additions and 10 deletions.
tests/trace/test_dataset.py (11 additions, 0 deletions)
@@ -1,6 +1,7 @@
 import pytest
 
 import weave
+from weave.trace.context.tests_context import raise_on_captured_errors
@@ -54,3 +55,13 @@ def greet(name: str, age: int) -> str:
     assert rows[1]["inputs"]["name"] == "Bob"
     assert rows[1]["inputs"]["age"] == 25
     assert rows[1]["output"] == "Hello Bob, you are 25!"
+
+
+def test_dataset_caching(client):
+    ds = weave.Dataset(rows=[{"a": i} for i in range(200)])
+    ref = weave.publish(ds)
+
+    ds2 = ref.get()
+
+    with raise_on_captured_errors():
+        assert len(ds2) == 200
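
The new test publishes a 200-row dataset and reads it back via ref.get(), asserting that all 200 rows survive the round trip. Assuming a page size of 100 (the actual constant lives in weave/trace/vals.py and is not shown in this diff), 200 rows forces a second page, which is exactly the path the old indexing got wrong. A minimal back-of-the-envelope check:

    # Assumed page size for illustration; the real value is defined in vals.py.
    page_size = 100
    num_rows = 200  # matches the test dataset
    num_pages = -(-num_rows // page_size)  # ceiling division
    assert num_pages == 2  # page_index 1 exercises the previously buggy lookup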
weave/trace/vals.py (22 additions, 10 deletions)
@@ -397,15 +397,27 @@ def _remote_iter(self) -> Generator[dict, None, None]:
                 )
             )
 
-            if self._prefetched_rows is not None and len(response.rows) != len(
-                self._prefetched_rows
-            ):
-                if get_raise_on_captured_errors():
-                    raise
-                logger.error(
-                    f"Expected length of response rows ({len(response.rows)}) to match prefetched rows ({len(self._prefetched_rows)}). Ignoring prefetched rows."
-                )
-                self._prefetched_rows = None
+            # When paginating through large datasets, we need special handling for prefetched rows
+            # on the first page. This is because prefetched_rows contains ALL rows, while each
+            # response page contains at most page_size rows.
+            if page_index == 0 and self._prefetched_rows is not None:
+                response_rows_len = len(response.rows)
+                prefetched_rows_len = len(self._prefetched_rows)
+
+                # There are two valid scenarios:
+                # 1. The response rows exactly match prefetched rows (small dataset, no pagination needed)
+                # 2. We're paginating a large dataset (response has page_size rows, prefetched has more)
+                #
+                # Any other mismatch indicates an inconsistency that should be handled by
+                # discarding the prefetched rows and relying solely on server responses.
+                if response_rows_len != prefetched_rows_len and not (
+                    response_rows_len == page_size and prefetched_rows_len > page_size
+                ):
+                    msg = f"Expected length of response rows ({response_rows_len}) to match prefetched rows ({prefetched_rows_len}). Ignoring prefetched rows."
+                    if get_raise_on_captured_errors():
+                        raise ValueError(msg)
+                    logger.debug(msg)
+                    self._prefetched_rows = None
 
             for i, item in enumerate(response.rows):
                 new_ref = self.ref.with_item(item.digest) if self.ref else None
@@ -418,7 +430,7 @@ def _remote_iter(self) -> Generator[dict, None, None]:
                 val = (
                     item.val
                     if self._prefetched_rows is None
-                    else self._prefetched_rows[i]
+                    else self._prefetched_rows[page_index * page_size + i]
                 )
                 res = from_json(val, self.table_ref.project_id, self.server)
                 res = make_trace_obj(res, new_ref, self.server, self.root)
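The second hunk is the heart of the fix: self._prefetched_rows holds the entire dataset, while i only indexes within the current page, so the lookup must be offset by page_index * page_size. A self-contained sketch of the old and new behavior, with illustrative values (this is not weave's actual code; prefetched_rows and page_size just mirror the names in the diff):

    # Sketch: mapping a within-page index back into one prefetched list of all rows.
    prefetched_rows = [{"a": i} for i in range(200)]  # every row, fetched up front
    page_size = 100  # assumed page size

    def iter_rows(buggy: bool = False):
        num_pages = -(-len(prefetched_rows) // page_size)  # ceiling division
        for page_index in range(num_pages):
            page = prefetched_rows[page_index * page_size : (page_index + 1) * page_size]
            for i, _ in enumerate(page):
                if buggy:
                    yield prefetched_rows[i]  # old lookup: re-reads page 0 on every page
                else:
                    yield prefetched_rows[page_index * page_size + i]  # fixed lookup

    assert [r["a"] for r in iter_rows()] == list(range(200))
    assert [r["a"] for r in iter_rows(buggy=True)] == list(range(100)) * 2  # pages repeat

The guard added in the first hunk complements this: on page 0 it now tolerates the expected case where the response holds one page while prefetched_rows holds everything, and only discards the cache (or raises under raise_on_captured_errors) on a genuine mismatch.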
