Skip to content

Commit

Permalink
feat(window): implement window functions core and bindings
Browse files Browse the repository at this point in the history
  • Loading branch information
f4t4nt committed Feb 20, 2025
1 parent 0998d49 commit 71c5fdc
Show file tree
Hide file tree
Showing 7 changed files with 424 additions and 60 deletions.
100 changes: 100 additions & 0 deletions daft/daft/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,106 @@ class ImageMode(Enum):
"""
...

class WindowBoundary:
    """Represents a window frame boundary in window functions.

    This stub declares only static factory methods, one per boundary kind;
    each returns a ``WindowBoundary``.
    """

    @staticmethod
    def UnboundedPreceding() -> WindowBoundary:
        """Represents UNBOUNDED PRECEDING boundary.

        Returns:
            The boundary value for UNBOUNDED PRECEDING.
        """
        ...

    @staticmethod
    def UnboundedFollowing() -> WindowBoundary:
        """Represents UNBOUNDED FOLLOWING boundary.

        Returns:
            The boundary value for UNBOUNDED FOLLOWING.
        """
        ...

    @staticmethod
    def CurrentRow() -> WindowBoundary:
        """Represents CURRENT ROW boundary.

        Returns:
            The boundary value for CURRENT ROW.
        """
        ...

    @staticmethod
    def Preceding(n: int) -> WindowBoundary:
        """Represents N PRECEDING boundary.

        Args:
            n: The offset N. Presumably a non-negative count; the accepted
                range is not visible in this stub -- TODO confirm.

        Returns:
            The boundary value for N PRECEDING.
        """
        ...

    @staticmethod
    def Following(n: int) -> WindowBoundary:
        """Represents N FOLLOWING boundary.

        Args:
            n: The offset N. Presumably a non-negative count; the accepted
                range is not visible in this stub -- TODO confirm.

        Returns:
            The boundary value for N FOLLOWING.
        """
        ...

class WindowFrameType:
    """Represents the type of window frame (ROWS or RANGE).

    This stub declares one static factory method per frame type.
    """

    @staticmethod
    def Rows() -> WindowFrameType:
        """Row-based window frame.

        Returns:
            The frame-type value for ROWS.
        """
        ...

    @staticmethod
    def Range() -> WindowFrameType:
        """Range-based window frame.

        Returns:
            The frame-type value for RANGE.
        """
        ...

class WindowFrame:
    """Represents a window frame specification.

    A frame combines a frame type with a start and an end boundary.
    """

    def __init__(
        self,
        frame_type: WindowFrameType,
        start: WindowBoundary,
        end: WindowBoundary,
    ) -> None:
        """Create a new window frame specification.

        Args:
            frame_type: Type of window frame (ROWS or RANGE).
            start: Start boundary of window frame.
            end: End boundary of window frame.
        """
        ...

class WindowSpec:
    """Represents a window specification for window functions.

    Every ``with_*`` method is declared to return a ``WindowSpec``, so calls
    can be chained. NOTE(review): whether each call returns a new object or
    the receiver itself is not visible in this stub -- confirm against the
    native implementation.
    """

    @staticmethod
    def new() -> WindowSpec:
        """Create a new empty window specification.

        Returns:
            An empty window specification.
        """
        ...

    def with_partition_by(self, exprs: list[PyExpr]) -> WindowSpec:
        """Set the partition by expressions.

        Args:
            exprs: List of expressions to partition by.

        Returns:
            A window specification with the partition-by expressions set.
        """
        ...

    def with_order_by(self, exprs: list[PyExpr], ascending: list[bool]) -> WindowSpec:
        """Set the order by expressions.

        Args:
            exprs: List of expressions to order by.
            ascending: List of booleans indicating sort order for each
                expression. Presumably expected to have the same length as
                ``exprs`` -- TODO confirm.

        Returns:
            A window specification with the order-by expressions set.
        """
        ...

    def with_frame(self, frame: WindowFrame) -> WindowSpec:
        """Set the window frame specification.

        Args:
            frame: Window frame specification.

        Returns:
            A window specification with the frame set.
        """
        ...

    def with_min_periods(self, min_periods: int) -> WindowSpec:
        """Set the minimum number of rows required to compute a result.

        Args:
            min_periods: Minimum number of rows required.

        Returns:
            A window specification with the minimum-periods value set.
        """
        ...

class ImageFormat(Enum):
"""Supported image formats for Daft's image I/O."""

Expand Down
22 changes: 10 additions & 12 deletions daft/expressions/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1459,7 +1459,7 @@ def rank(self) -> Expression:
Returns:
Expression: Expression containing rank values
"""
return Expression._from_pyexpr(native.rank(self._expr))
raise NotImplementedError("Window functions are not yet implemented")

def dense_rank(self) -> Expression:
"""Compute dense rank within window partition.
Expand All @@ -1470,7 +1470,7 @@ def dense_rank(self) -> Expression:
Returns:
Expression: Expression containing dense rank values
"""
return Expression._from_pyexpr(native.dense_rank(self._expr))
raise NotImplementedError("Window functions are not yet implemented")

def row_number(self) -> Expression:
"""Compute row number within window partition.
Expand All @@ -1481,7 +1481,7 @@ def row_number(self) -> Expression:
Returns:
Expression: Expression containing row numbers
"""
return Expression._from_pyexpr(native.row_number(self._expr))
raise NotImplementedError("Window functions are not yet implemented")

def percent_rank(self) -> Expression:
"""Compute percent rank within window partition.
Expand All @@ -1492,7 +1492,7 @@ def percent_rank(self) -> Expression:
Returns:
Expression: Expression containing percent rank values
"""
return Expression._from_pyexpr(native.percent_rank(self._expr))
raise NotImplementedError("Window functions are not yet implemented")

def ntile(self, n: int) -> Expression:
"""Divide rows in partition into n buckets numbered from 1 to n.
Expand All @@ -1506,7 +1506,7 @@ def ntile(self, n: int) -> Expression:
Returns:
Expression: Expression containing bucket numbers
"""
return Expression._from_pyexpr(native.ntile(self._expr, n))
raise NotImplementedError("Window functions are not yet implemented")

def lag(self, offset: int = 1, default: Any = None) -> Expression:
"""Access value from previous row in partition.
Expand All @@ -1518,8 +1518,7 @@ def lag(self, offset: int = 1, default: Any = None) -> Expression:
Returns:
Expression: Expression containing lagged values
"""
default_expr = Expression._to_expression(default)._expr if default is not None else None
return Expression._from_pyexpr(native.lag(self._expr, offset, default_expr))
raise NotImplementedError("Window functions are not yet implemented")

def lead(self, offset: int = 1, default: Any = None) -> Expression:
"""Access value from following row in partition.
Expand All @@ -1531,24 +1530,23 @@ def lead(self, offset: int = 1, default: Any = None) -> Expression:
Returns:
Expression: Expression containing leading values
"""
default_expr = Expression._to_expression(default)._expr if default is not None else None
return Expression._from_pyexpr(native.lead(self._expr, offset, default_expr))
raise NotImplementedError("Window functions are not yet implemented")

def first_value(self) -> Expression:
"""Get first value in window frame.
Returns:
Expression: Expression containing first values
"""
return Expression._from_pyexpr(native.first_value(self._expr))
raise NotImplementedError("Window functions are not yet implemented")

def last_value(self) -> Expression:
"""Get last value in window frame.
Returns:
Expression: Expression containing last values
"""
return Expression._from_pyexpr(native.last_value(self._expr))
raise NotImplementedError("Window functions are not yet implemented")

def nth_value(self, n: int) -> Expression:
"""Get nth value in window frame.
Expand All @@ -1559,7 +1557,7 @@ def nth_value(self, n: int) -> Expression:
Returns:
Expression: Expression containing nth values
"""
return Expression._from_pyexpr(native.nth_value(self._expr, n))
raise NotImplementedError("Window functions are not yet implemented")


SomeExpressionNamespace = TypeVar("SomeExpressionNamespace", bound="ExpressionNamespace")
Expand Down
107 changes: 74 additions & 33 deletions daft/window.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
from typing import List, Union
from __future__ import annotations

from typing import Any

class WindowBoundary:
"""Represents window frame boundaries."""

def __init__(self, name: str):
self.name = name

def __repr__(self) -> str:
return f"WindowBoundary({self.name})"
from daft.daft import WindowBoundary as _WindowBoundary
from daft.daft import WindowFrame as _WindowFrame
from daft.daft import WindowFrameType as _WindowFrameType
from daft.daft import WindowSpec as _WindowSpec
from daft.expressions import col


class Window:
Expand All @@ -20,21 +18,15 @@ class Window:
"""

# Class-level constants for frame boundaries
unbounded_preceding = WindowBoundary("UNBOUNDED PRECEDING")
unbounded_following = WindowBoundary("UNBOUNDED FOLLOWING")
current_row = WindowBoundary("CURRENT ROW")
unbounded_preceding = _WindowBoundary.UnboundedPreceding()
unbounded_following = _WindowBoundary.UnboundedFollowing()
current_row = _WindowBoundary.CurrentRow()

def __init__(self):
self.partition_by = None
self.order_by = None
self.frame_start = self.unbounded_preceding
self.frame_end = self.unbounded_following
self._spec = _WindowSpec.new()

def __repr__(self) -> str:
return f"Window(partition_by={self.partition_by}, order_by={self.order_by}, frame_start={self.frame_start}, frame_end={self.frame_end})"

@staticmethod
def partition_by(*cols: Union[str, List[str]]) -> "Window":
@classmethod
def partition_by(cls, *cols: str | list[str]) -> Window:
"""Partitions the dataset by one or more columns.
Args:
Expand All @@ -46,9 +38,24 @@ def partition_by(*cols: Union[str, List[str]]) -> "Window":
Raises:
ValueError: If no partition columns are specified.
"""
raise NotImplementedError

def order_by(self, *cols: Union[str, List[str]], ascending: Union[bool, List[bool]] = True) -> "Window":
if not cols:
raise ValueError("At least one partition column must be specified")

# Flatten list arguments
flat_cols = []
for c in cols:
if isinstance(c, list):
flat_cols.extend(c)
else:
flat_cols.append(c)

# Create new Window with updated spec
window = cls()
window._spec = window._spec.with_partition_by([col(c)._expr for c in flat_cols])
return window

@classmethod
def order_by(cls, *cols: str | list[str], ascending: bool | list[bool] = True) -> Window:
"""Orders rows within each partition by specified columns.
Args:
Expand All @@ -58,14 +65,33 @@ def order_by(self, *cols: Union[str, List[str]], ascending: Union[bool, List[boo
Returns:
Window: A window specification with the given ordering.
"""
raise NotImplementedError
# Flatten list arguments
flat_cols = []
for c in cols:
if isinstance(c, list):
flat_cols.extend(c)
else:
flat_cols.append(c)

# Handle ascending parameter
if isinstance(ascending, bool):
asc_flags = [ascending] * len(flat_cols)
else:
if len(ascending) != len(flat_cols):
raise ValueError("Length of ascending flags must match number of order by columns")
asc_flags = ascending

# Create new Window with updated spec
window = cls()
window._spec = window._spec.with_order_by([col(c)._expr for c in flat_cols], asc_flags)
return window

def rows_between(
self,
start: Union[int, WindowBoundary] = unbounded_preceding,
end: Union[int, WindowBoundary] = unbounded_following,
start: int | Any = unbounded_preceding,
end: int | Any = unbounded_following,
min_periods: int = 1,
) -> "Window":
) -> Window:
"""Restricts each window to a row-based frame between start and end boundaries.
Args:
Expand All @@ -76,14 +102,29 @@ def rows_between(
Returns:
Window: A window specification with the given frame bounds.
"""
raise NotImplementedError
# Convert integer offsets to WindowBoundary
if isinstance(start, int):
start = _WindowBoundary.Preceding(-start) if start < 0 else _WindowBoundary.Following(start)
if isinstance(end, int):
end = _WindowBoundary.Preceding(-end) if end < 0 else _WindowBoundary.Following(end)

frame = _WindowFrame(
frame_type=_WindowFrameType.Rows(),
start=start,
end=end,
)

# Create new Window with updated spec
new_window = Window()
new_window._spec = self._spec.with_frame(frame).with_min_periods(min_periods)
return new_window

def range_between(
self,
start: Union[int, WindowBoundary] = unbounded_preceding,
end: Union[int, WindowBoundary] = unbounded_following,
start: int | Any = unbounded_preceding,
end: int | Any = unbounded_following,
min_periods: int = 1,
) -> "Window":
) -> Window:
"""Restricts each window to a range-based frame between start and end boundaries.
Args:
Expand All @@ -94,4 +135,4 @@ def range_between(
Returns:
Window: A window specification with the given frame bounds.
"""
raise NotImplementedError
raise NotImplementedError("Window.range_between is not implemented yet")
Loading

0 comments on commit 71c5fdc

Please sign in to comment.