diff --git a/.github/workflows/run-cluster.yaml b/.github/workflows/run-cluster.yaml
index 644250d1f1..214e9d879b 100644
--- a/.github/workflows/run-cluster.yaml
+++ b/.github/workflows/run-cluster.yaml
@@ -123,6 +123,7 @@ jobs:
             --runtime-env-json "$ray_env_var" \
             -- python ${{ inputs.entrypoint_script }} ${{ inputs.entrypoint_args }}
       - name: Download log files from ray cluster
+        if: always()
         run: |
           source .venv/bin/activate
           ray rsync-down .github/assets/ray.yaml /tmp/ray/session_*/logs ray-daft-logs
@@ -152,6 +153,7 @@ jobs:
           source .venv/bin/activate
           ray down .github/assets/ray.yaml -y
       - name: Upload log files
+        if: always()
         uses: actions/upload-artifact@v4
         with:
           name: ray-daft-logs
diff --git a/benchmarking/ooms/big_task_heap_usage.py b/benchmarking/ooms/big_task_heap_usage.py
new file mode 100644
index 0000000000..b04087ab30
--- /dev/null
+++ b/benchmarking/ooms/big_task_heap_usage.py
@@ -0,0 +1,71 @@
+# /// script
+# dependencies = ['numpy', 'memray']
+# ///
+
+import argparse
+import functools
+
+import pyarrow as pa
+
+import daft
+from daft.io._generator import read_generator
+from daft.table.table import Table
+
+NUM_PARTITIONS = 8
+
+
+@daft.udf(return_dtype=daft.DataType.binary())
+def mock_inflate_data(data, inflation_factor):
+    return pa.array([x * inflation_factor for x in data.to_pylist()], type=pa.large_binary())
+
+
+@daft.udf(return_dtype=daft.DataType.binary())
+def mock_deflate_data(data, deflation_factor):
+    return [x[: int(len(x) / deflation_factor)] for x in data.to_pylist()]
+
+
+def generate(num_rows_per_partition):
+    yield Table.from_pydict({"foo": [b"x" for _ in range(num_rows_per_partition)]})
+
+
+def generator(
+    num_partitions: int,
+    num_rows_per_partition: int,
+):
+    """Generate data for all partitions."""
+    for i in range(num_partitions):
+        yield functools.partial(generate, num_rows_per_partition)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        "Runs a simple map workload with 2 custom UDFs: the first inflates the data and the second deflates it. "
+        "It starts with 1KB partitions, then runs inflation and subsequently deflation. We expect this to OOM if the heap memory usage exceeds "
+        "`MEM / N_CPUS` on a given worker node."
+    )
+    parser.add_argument("--num-partitions", type=int, default=8)
+    parser.add_argument("--num-rows-per-partition", type=int, default=1000)
+    parser.add_argument("--inflation-factor", type=int, default=100)
+    parser.add_argument("--deflation-factor", type=int, default=100)
+    args = parser.parse_args()
+
+    daft.context.set_runner_ray()
+
+    df = read_generator(
+        generator(args.num_partitions, args.num_rows_per_partition),
+        schema=daft.Schema._from_field_name_and_types([("foo", daft.DataType.binary())]),
+    )
+
+    df.collect()
+    print(df)
+
+    # Big memory explosion
+    df = df.with_column("foo", mock_inflate_data(df["foo"], args.inflation_factor))
+
+    # Big memory reduction
+    df = df.with_column("foo", mock_deflate_data(df["foo"], args.deflation_factor))
+
+    df.explain(True)
+
+    df.collect()
+    print(df)
diff --git a/daft/datatype.py b/daft/datatype.py
index b15902c41d..a47491477f 100644
--- a/daft/datatype.py
+++ b/daft/datatype.py
@@ -576,38 +576,43 @@ def __hash__(self) -> int:
 DataTypeLike = Union[DataType, type]
 
+import threading
+
 _EXT_TYPE_REGISTERED = False
 _STATIC_DAFT_EXTENSION = None
+_ext_type_lock = threading.Lock()
 
 
 def _ensure_registered_super_ext_type():
     global _EXT_TYPE_REGISTERED
     global _STATIC_DAFT_EXTENSION
-    if not _EXT_TYPE_REGISTERED:
-
-        class DaftExtension(pa.ExtensionType):
-            def __init__(self, dtype, metadata=b""):
-                # attributes need to be set first before calling
-                # super init (as that calls serialize)
-                self._metadata = metadata
-                super().__init__(dtype, "daft.super_extension")
+    with _ext_type_lock:
+        if not _EXT_TYPE_REGISTERED:
+
+            class DaftExtension(pa.ExtensionType):
+                def __init__(self, dtype, metadata=b""):
+                    # attributes need to be set first before calling
+                    # super init (as that calls serialize)
+                    self._metadata = metadata
+                    super().__init__(dtype, "daft.super_extension")
 
-            def __reduce__(self):
-                return type(self).__arrow_ext_deserialize__, (self.storage_type, self.__arrow_ext_serialize__())
+                def __reduce__(self):
+                    return type(self).__arrow_ext_deserialize__, (self.storage_type, self.__arrow_ext_serialize__())
 
-            def __arrow_ext_serialize__(self):
-                return self._metadata
+                def __arrow_ext_serialize__(self):
+                    return self._metadata
 
-            @classmethod
-            def __arrow_ext_deserialize__(cls, storage_type, serialized):
-                return cls(storage_type, serialized)
+                @classmethod
+                def __arrow_ext_deserialize__(cls, storage_type, serialized):
+                    return cls(storage_type, serialized)
 
-        _STATIC_DAFT_EXTENSION = DaftExtension
-        pa.register_extension_type(DaftExtension(pa.null()))
-        import atexit
+            _STATIC_DAFT_EXTENSION = DaftExtension
+            pa.register_extension_type(DaftExtension(pa.null()))
+            import atexit
 
-        atexit.register(lambda: pa.unregister_extension_type("daft.super_extension"))
-        _EXT_TYPE_REGISTERED = True
+            atexit.register(lambda: pa.unregister_extension_type("daft.super_extension"))
+            _EXT_TYPE_REGISTERED = True
 
 
 def get_super_ext_type():
diff --git a/daft/runners/ray_metrics.py b/daft/runners/ray_metrics.py
index df542446c6..31bafc871e 100644
--- a/daft/runners/ray_metrics.py
+++ b/daft/runners/ray_metrics.py
@@ -51,6 +51,14 @@ class EndTaskEvent(TaskEvent):
 
     # End Unix timestamp
     end: float
+    memory_stats: TaskMemoryStats
+
+
+@dataclasses.dataclass(frozen=True)
+class TaskMemoryStats:
+    peak_memory_allocated: int
+    total_memory_allocated: int
+    total_num_allocations: int
 
 
 class _NodeInfo:
@@ -123,9 +131,15 @@ def mark_task_start(
             )
         )
 
-    def mark_task_end(self, execution_id: str, task_id: str, end: float):
+    def mark_task_end(
+        self,
+        execution_id: str,
+        task_id: str,
+        end: float,
+        memory_stats: TaskMemoryStats,
+    ):
         # Add an EndTaskEvent
-        self._task_events[execution_id].append(EndTaskEvent(task_id=task_id, end=end))
+        self._task_events[execution_id].append(EndTaskEvent(task_id=task_id, end=end, memory_stats=memory_stats))
 
     def get_task_events(self, execution_id: str, idx: int) -> tuple[list[TaskEvent], int]:
         events = self._task_events[execution_id]
@@ -177,11 +191,13 @@ def mark_task_end(
         self,
         task_id: str,
         end: float,
+        memory_stats: TaskMemoryStats,
     ) -> None:
         self.actor.mark_task_end.remote(
             self.execution_id,
             task_id,
             end,
+            memory_stats,
         )
 
     def get_task_events(self, idx: int) -> tuple[list[TaskEvent], int]:
diff --git a/daft/runners/ray_tracing.py b/daft/runners/ray_tracing.py
index b200651a76..74980bcc21 100644
--- a/daft/runners/ray_tracing.py
+++ b/daft/runners/ray_tracing.py
@@ -10,6 +10,7 @@
 import dataclasses
 import json
 import logging
+import os
 import pathlib
 import time
 from datetime import datetime
@@ -255,6 +256,11 @@ def _flush_task_metrics(self):
                     "ph": RunnerTracer.PHASE_ASYNC_END,
                     "pid": 1,
                     "tid": 2,
+                    "args": {
+                        "memray_peak_memory_allocated": task_event.memory_stats.peak_memory_allocated,
+                        "memray_total_memory_allocated": task_event.memory_stats.total_memory_allocated,
+                        "memray_total_num_allocations": task_event.memory_stats.total_num_allocations,
+                    },
                 },
                 ts=end_ts,
             )
@@ -272,6 +278,11 @@ def _flush_task_metrics(self):
                     "ph": RunnerTracer.PHASE_DURATION_END,
                     "pid": node_idx + RunnerTracer.NODE_PIDS_START,
                     "tid": worker_idx,
+                    "args": {
+                        "memray_peak_memory_allocated": task_event.memory_stats.peak_memory_allocated,
+                        "memray_total_memory_allocated": task_event.memory_stats.total_memory_allocated,
+                        "memray_total_num_allocations": task_event.memory_stats.total_num_allocations,
+                    },
                 },
                 ts=end_ts,
             )
@@ -658,6 +669,9 @@ def collect_ray_task_metrics(execution_id: str, task_id: str, stage_id: int, exe
     if execution_config.enable_ray_tracing:
         import time
 
+        import memray
+        from memray._memray import compute_statistics
+
         runtime_context = ray.get_runtime_context()
         metrics_actor = ray_metrics.get_metrics_actor(execution_id)
@@ -670,7 +684,22 @@ def collect_ray_task_metrics(execution_id: str, task_id: str, stage_id: int, exe
             runtime_context.get_assigned_resources(),
             runtime_context.get_task_id(),
         )
-        yield
-        metrics_actor.mark_task_end(task_id, time.time())
+        tmpdir = "/tmp/ray/session_latest/logs/daft/task_memray_dumps"
+        os.makedirs(tmpdir, exist_ok=True)
+        memray_tmpfile = os.path.join(tmpdir, f"task-{task_id}.memray.bin")
+        try:
+            with memray.Tracker(memray_tmpfile, native_traces=True, follow_fork=True):
+                yield
+        finally:
+            stats = compute_statistics(memray_tmpfile)
+            metrics_actor.mark_task_end(
+                task_id,
+                time.time(),
+                ray_metrics.TaskMemoryStats(
+                    peak_memory_allocated=stats.peak_memory_allocated,
+                    total_memory_allocated=stats.total_memory_allocated,
+                    total_num_allocations=stats.total_num_allocations,
+                ),
+            )
     else:
         yield
diff --git a/tests/memory/__init__.py b/tests/memory/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/tests/memory/test_udf_project.py b/tests/memory/test_udf_project.py
new file mode 100644
index 0000000000..c67cda0e28
--- /dev/null
+++ b/tests/memory/test_udf_project.py
@@ -0,0 +1,94 @@
+import uuid
+
+import pyarrow as pa
+import pytest
+from memray._memray import compute_statistics
+
+import daft
+from daft.execution.execution_step import ExpressionsProjection, Project
+from tests.memory.utils import run_wrapper_build_partitions
+
+
+def format_bytes(bytes_value):
+    """Format bytes into a human-readable string with the appropriate unit."""
+    for unit in ["B", "KB", "MB", "GB"]:
+        if bytes_value < 1024:
+            return f"{bytes_value:.2f} {unit}"
+        bytes_value /= 1024
+    return f"{bytes_value:.2f} GB"
+
+
+@daft.udf(return_dtype=str)
+def to_arrow_identity(s):
+    data = s.to_arrow()
+    return data
+
+
+@daft.udf(return_dtype=str)
+def to_pylist_identity(s):
+    data = s.to_pylist()
+    return data
+
+
+@daft.udf(return_dtype=str, batch_size=128)
+def to_arrow_identity_batched(s):
+    data = s.to_arrow()
+    return data
+
+
+@daft.udf(return_dtype=str, batch_size=128)
+def to_pylist_identity_batched(s):
+    data = s.to_pylist()
+    return data
+
+
+@daft.udf(return_dtype=str, batch_size=128)
+def to_pylist_identity_batched_arrow_return(s):
+    data = s.to_pylist()
+    return pa.array(data)
+
+
+@pytest.mark.parametrize(
+    "udf",
+    [
+        to_arrow_identity,
+        to_pylist_identity,
+        to_arrow_identity_batched,
+        to_pylist_identity_batched,
+        to_pylist_identity_batched_arrow_return,
+    ],
+)
+def test_short_string_identity_projection(udf):
+    instructions = [Project(ExpressionsProjection([udf(daft.col("a"))]))]
+    inputs = [{"a": [str(uuid.uuid4()) for _ in range(62500)]}]
+    _, memray_file = run_wrapper_build_partitions(inputs, instructions)
+    stats = compute_statistics(memray_file)
+
+    expected_peak_bytes = 100
+    assert stats.peak_memory_allocated < expected_peak_bytes, (
+        f"Peak memory ({format_bytes(stats.peak_memory_allocated)}) "
+        f"exceeded threshold ({format_bytes(expected_peak_bytes)})"
+    )
+
+
+@pytest.mark.parametrize(
+    "udf",
+    [
+        to_arrow_identity,
+        to_pylist_identity,
+        to_arrow_identity_batched,
+        to_pylist_identity_batched,
+        to_pylist_identity_batched_arrow_return,
+    ],
+)
+def test_long_string_identity_projection(udf):
+    instructions = [Project(ExpressionsProjection([udf(daft.col("a"))]))]
+    inputs = [{"a": [str(uuid.uuid4()) for _ in range(625000)]}]
+    _, memray_file = run_wrapper_build_partitions(inputs, instructions)
+    stats = compute_statistics(memray_file)
+
+    expected_peak_bytes = 100
+    assert stats.peak_memory_allocated < expected_peak_bytes, (
+        f"Peak memory ({format_bytes(stats.peak_memory_allocated)}) "
+        f"exceeded threshold ({format_bytes(expected_peak_bytes)})"
+    )
diff --git a/tests/memory/utils.py b/tests/memory/utils.py
new file mode 100644
index 0000000000..a1f17f706d
--- /dev/null
+++ b/tests/memory/utils.py
@@ -0,0 +1,31 @@
+import logging
+import os
+import tempfile
+import uuid
+from unittest import mock
+
+import memray
+
+from daft.execution.execution_step import Instruction
+from daft.runners.ray_runner import build_partitions
+from daft.table import MicroPartition
+
+logger = logging.getLogger(__name__)
+
+
+def run_wrapper_build_partitions(
+    input_partitions: list[dict], instructions: list[Instruction]
+) -> tuple[list[MicroPartition], str]:
+    inputs = [MicroPartition.from_pydict(p) for p in input_partitions]
+
+    logger.info("Input total size: %s", sum(i.size_bytes() for i in inputs))
+
+    tmpdir = tempfile.gettempdir()
+    memray_path = os.path.join(tmpdir, f"memray-{uuid.uuid4()}.bin")
+    with memray.Tracker(memray_path, native_traces=True, follow_fork=True):
+        results = build_partitions(
+            instructions,
+            [mock.Mock() for _ in range(len(input_partitions))],
+            *inputs,
+        )
+    return results[1:], memray_path
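
Note (not part of the patch): a minimal sketch of how the per-task memray dumps written by collect_ray_task_metrics above could be inspected after a run. It assumes memray is installed and reuses the dump directory and the compute_statistics helper that the patch itself uses; the glob pattern and print format here are illustrative only.

    import glob
    import os

    from memray._memray import compute_statistics

    # Directory where collect_ray_task_metrics writes one dump per task (see daft/runners/ray_tracing.py above).
    DUMP_DIR = "/tmp/ray/session_latest/logs/daft/task_memray_dumps"

    for path in sorted(glob.glob(os.path.join(DUMP_DIR, "task-*.memray.bin"))):
        # Same statistics the tracer forwards to the metrics actor.
        stats = compute_statistics(path)
        print(
            f"{os.path.basename(path)}: "
            f"peak={stats.peak_memory_allocated} B, "
            f"total={stats.total_memory_allocated} B, "
            f"allocations={stats.total_num_allocations}"
        )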