From 71c5fdc2ca9811e2f897e463bda65c2d73c478a8 Mon Sep 17 00:00:00 2001 From: Nishant Bhakar Date: Wed, 19 Feb 2025 16:48:13 -0800 Subject: [PATCH] feat(window): implement window functions core and bindings --- daft/daft/__init__.pyi | 100 ++++++++++++++++++++ daft/expressions/expressions.py | 22 ++--- daft/window.py | 107 ++++++++++++++------- src/daft-dsl/src/expr/window.rs | 97 ++++++++++++++++--- src/daft-dsl/src/functions/mod.rs | 4 +- src/daft-dsl/src/lib.rs | 5 + src/daft-dsl/src/python.rs | 149 +++++++++++++++++++++++++++++- 7 files changed, 424 insertions(+), 60 deletions(-) diff --git a/daft/daft/__init__.pyi b/daft/daft/__init__.pyi index d59f6f0fff..8edca3d451 100644 --- a/daft/daft/__init__.pyi +++ b/daft/daft/__init__.pyi @@ -68,6 +68,106 @@ class ImageMode(Enum): """ ... +class WindowBoundary: + """Represents a window frame boundary in window functions.""" + + @staticmethod + def UnboundedPreceding() -> WindowBoundary: + """Represents UNBOUNDED PRECEDING boundary.""" + ... + + @staticmethod + def UnboundedFollowing() -> WindowBoundary: + """Represents UNBOUNDED FOLLOWING boundary.""" + ... + + @staticmethod + def CurrentRow() -> WindowBoundary: + """Represents CURRENT ROW boundary.""" + ... + + @staticmethod + def Preceding(n: int) -> WindowBoundary: + """Represents N PRECEDING boundary.""" + ... + + @staticmethod + def Following(n: int) -> WindowBoundary: + """Represents N FOLLOWING boundary.""" + ... + +class WindowFrameType: + """Represents the type of window frame (ROWS or RANGE).""" + + @staticmethod + def Rows() -> WindowFrameType: + """Row-based window frame.""" + ... + + @staticmethod + def Range() -> WindowFrameType: + """Range-based window frame.""" + ... + +class WindowFrame: + """Represents a window frame specification.""" + + def __init__( + self, + frame_type: WindowFrameType, + start: WindowBoundary, + end: WindowBoundary, + ) -> None: + """Create a new window frame specification. + + Args: + frame_type: Type of window frame (ROWS or RANGE) + start: Start boundary of window frame + end: End boundary of window frame + """ + ... + +class WindowSpec: + """Represents a window specification for window functions.""" + + @staticmethod + def new() -> WindowSpec: + """Create a new empty window specification.""" + ... + + def with_partition_by(self, exprs: list[PyExpr]) -> WindowSpec: + """Set the partition by expressions. + + Args: + exprs: List of expressions to partition by + """ + ... + + def with_order_by(self, exprs: list[PyExpr], ascending: list[bool]) -> WindowSpec: + """Set the order by expressions. + + Args: + exprs: List of expressions to order by + ascending: List of booleans indicating sort order for each expression + """ + ... + + def with_frame(self, frame: WindowFrame) -> WindowSpec: + """Set the window frame specification. + + Args: + frame: Window frame specification + """ + ... + + def with_min_periods(self, min_periods: int) -> WindowSpec: + """Set the minimum number of rows required to compute a result. + + Args: + min_periods: Minimum number of rows required + """ + ... + class ImageFormat(Enum): """Supported image formats for Daft's image I/O.""" diff --git a/daft/expressions/expressions.py b/daft/expressions/expressions.py index 6134fb124e..daa437243e 100644 --- a/daft/expressions/expressions.py +++ b/daft/expressions/expressions.py @@ -1459,7 +1459,7 @@ def rank(self) -> Expression: Returns: Expression: Expression containing rank values """ - return Expression._from_pyexpr(native.rank(self._expr)) + raise NotImplementedError("Window functions are not yet implemented") def dense_rank(self) -> Expression: """Compute dense rank within window partition. @@ -1470,7 +1470,7 @@ def dense_rank(self) -> Expression: Returns: Expression: Expression containing dense rank values """ - return Expression._from_pyexpr(native.dense_rank(self._expr)) + raise NotImplementedError("Window functions are not yet implemented") def row_number(self) -> Expression: """Compute row number within window partition. @@ -1481,7 +1481,7 @@ def row_number(self) -> Expression: Returns: Expression: Expression containing row numbers """ - return Expression._from_pyexpr(native.row_number(self._expr)) + raise NotImplementedError("Window functions are not yet implemented") def percent_rank(self) -> Expression: """Compute percent rank within window partition. @@ -1492,7 +1492,7 @@ def percent_rank(self) -> Expression: Returns: Expression: Expression containing percent rank values """ - return Expression._from_pyexpr(native.percent_rank(self._expr)) + raise NotImplementedError("Window functions are not yet implemented") def ntile(self, n: int) -> Expression: """Divide rows in partition into n buckets numbered from 1 to n. @@ -1506,7 +1506,7 @@ def ntile(self, n: int) -> Expression: Returns: Expression: Expression containing bucket numbers """ - return Expression._from_pyexpr(native.ntile(self._expr, n)) + raise NotImplementedError("Window functions are not yet implemented") def lag(self, offset: int = 1, default: Any = None) -> Expression: """Access value from previous row in partition. @@ -1518,8 +1518,7 @@ def lag(self, offset: int = 1, default: Any = None) -> Expression: Returns: Expression: Expression containing lagged values """ - default_expr = Expression._to_expression(default)._expr if default is not None else None - return Expression._from_pyexpr(native.lag(self._expr, offset, default_expr)) + raise NotImplementedError("Window functions are not yet implemented") def lead(self, offset: int = 1, default: Any = None) -> Expression: """Access value from following row in partition. @@ -1531,8 +1530,7 @@ def lead(self, offset: int = 1, default: Any = None) -> Expression: Returns: Expression: Expression containing leading values """ - default_expr = Expression._to_expression(default)._expr if default is not None else None - return Expression._from_pyexpr(native.lead(self._expr, offset, default_expr)) + raise NotImplementedError("Window functions are not yet implemented") def first_value(self) -> Expression: """Get first value in window frame. @@ -1540,7 +1538,7 @@ def first_value(self) -> Expression: Returns: Expression: Expression containing first values """ - return Expression._from_pyexpr(native.first_value(self._expr)) + raise NotImplementedError("Window functions are not yet implemented") def last_value(self) -> Expression: """Get last value in window frame. @@ -1548,7 +1546,7 @@ def last_value(self) -> Expression: Returns: Expression: Expression containing last values """ - return Expression._from_pyexpr(native.last_value(self._expr)) + raise NotImplementedError("Window functions are not yet implemented") def nth_value(self, n: int) -> Expression: """Get nth value in window frame. @@ -1559,7 +1557,7 @@ def nth_value(self, n: int) -> Expression: Returns: Expression: Expression containing nth values """ - return Expression._from_pyexpr(native.nth_value(self._expr, n)) + raise NotImplementedError("Window functions are not yet implemented") SomeExpressionNamespace = TypeVar("SomeExpressionNamespace", bound="ExpressionNamespace") diff --git a/daft/window.py b/daft/window.py index 2d7142e552..a612358051 100644 --- a/daft/window.py +++ b/daft/window.py @@ -1,14 +1,12 @@ -from typing import List, Union +from __future__ import annotations +from typing import Any -class WindowBoundary: - """Represents window frame boundaries.""" - - def __init__(self, name: str): - self.name = name - - def __repr__(self) -> str: - return f"WindowBoundary({self.name})" +from daft.daft import WindowBoundary as _WindowBoundary +from daft.daft import WindowFrame as _WindowFrame +from daft.daft import WindowFrameType as _WindowFrameType +from daft.daft import WindowSpec as _WindowSpec +from daft.expressions import col class Window: @@ -20,21 +18,15 @@ class Window: """ # Class-level constants for frame boundaries - unbounded_preceding = WindowBoundary("UNBOUNDED PRECEDING") - unbounded_following = WindowBoundary("UNBOUNDED FOLLOWING") - current_row = WindowBoundary("CURRENT ROW") + unbounded_preceding = _WindowBoundary.UnboundedPreceding() + unbounded_following = _WindowBoundary.UnboundedFollowing() + current_row = _WindowBoundary.CurrentRow() def __init__(self): - self.partition_by = None - self.order_by = None - self.frame_start = self.unbounded_preceding - self.frame_end = self.unbounded_following + self._spec = _WindowSpec.new() - def __repr__(self) -> str: - return f"Window(partition_by={self.partition_by}, order_by={self.order_by}, frame_start={self.frame_start}, frame_end={self.frame_end})" - - @staticmethod - def partition_by(*cols: Union[str, List[str]]) -> "Window": + @classmethod + def partition_by(cls, *cols: str | list[str]) -> Window: """Partitions the dataset by one or more columns. Args: @@ -46,9 +38,24 @@ def partition_by(*cols: Union[str, List[str]]) -> "Window": Raises: ValueError: If no partition columns are specified. """ - raise NotImplementedError - - def order_by(self, *cols: Union[str, List[str]], ascending: Union[bool, List[bool]] = True) -> "Window": + if not cols: + raise ValueError("At least one partition column must be specified") + + # Flatten list arguments + flat_cols = [] + for c in cols: + if isinstance(c, list): + flat_cols.extend(c) + else: + flat_cols.append(c) + + # Create new Window with updated spec + window = cls() + window._spec = window._spec.with_partition_by([col(c)._expr for c in flat_cols]) + return window + + @classmethod + def order_by(cls, *cols: str | list[str], ascending: bool | list[bool] = True) -> Window: """Orders rows within each partition by specified columns. Args: @@ -58,14 +65,33 @@ def order_by(self, *cols: Union[str, List[str]], ascending: Union[bool, List[boo Returns: Window: A window specification with the given ordering. """ - raise NotImplementedError + # Flatten list arguments + flat_cols = [] + for c in cols: + if isinstance(c, list): + flat_cols.extend(c) + else: + flat_cols.append(c) + + # Handle ascending parameter + if isinstance(ascending, bool): + asc_flags = [ascending] * len(flat_cols) + else: + if len(ascending) != len(flat_cols): + raise ValueError("Length of ascending flags must match number of order by columns") + asc_flags = ascending + + # Create new Window with updated spec + window = cls() + window._spec = window._spec.with_order_by([col(c)._expr for c in flat_cols], asc_flags) + return window def rows_between( self, - start: Union[int, WindowBoundary] = unbounded_preceding, - end: Union[int, WindowBoundary] = unbounded_following, + start: int | Any = unbounded_preceding, + end: int | Any = unbounded_following, min_periods: int = 1, - ) -> "Window": + ) -> Window: """Restricts each window to a row-based frame between start and end boundaries. Args: @@ -76,14 +102,29 @@ def rows_between( Returns: Window: A window specification with the given frame bounds. """ - raise NotImplementedError + # Convert integer offsets to WindowBoundary + if isinstance(start, int): + start = _WindowBoundary.Preceding(-start) if start < 0 else _WindowBoundary.Following(start) + if isinstance(end, int): + end = _WindowBoundary.Preceding(-end) if end < 0 else _WindowBoundary.Following(end) + + frame = _WindowFrame( + frame_type=_WindowFrameType.Rows(), + start=start, + end=end, + ) + + # Create new Window with updated spec + new_window = Window() + new_window._spec = self._spec.with_frame(frame).with_min_periods(min_periods) + return new_window def range_between( self, - start: Union[int, WindowBoundary] = unbounded_preceding, - end: Union[int, WindowBoundary] = unbounded_following, + start: int | Any = unbounded_preceding, + end: int | Any = unbounded_following, min_periods: int = 1, - ) -> "Window": + ) -> Window: """Restricts each window to a range-based frame between start and end boundaries. Args: @@ -94,4 +135,4 @@ def range_between( Returns: Window: A window specification with the given frame bounds. """ - raise NotImplementedError + raise NotImplementedError("Window.range_between is not implemented yet") diff --git a/src/daft-dsl/src/expr/window.rs b/src/daft-dsl/src/expr/window.rs index 07a4e0019b..e8cc9ccd9e 100644 --- a/src/daft-dsl/src/expr/window.rs +++ b/src/daft-dsl/src/expr/window.rs @@ -1,11 +1,16 @@ +use std::sync::Arc; + use common_error::DaftResult; -use daft_core::datatypes::DataType; +use daft_core::{datatypes::DataType, prelude::*}; +use serde::{Deserialize, Serialize}; -use crate::expr::Expr; +use crate::{ + expr::Expr, + functions::{FunctionEvaluator, FunctionExpr}, +}; /// Represents a window frame boundary -#[allow(dead_code)] -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] pub enum WindowFrameBoundary { /// Represents UNBOUNDED PRECEDING UnboundedPreceding, @@ -20,8 +25,7 @@ pub enum WindowFrameBoundary { } /// Represents the type of window frame (ROWS or RANGE) -#[allow(dead_code)] -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] pub enum WindowFrameType { /// Row-based window frame Rows, @@ -30,7 +34,7 @@ pub enum WindowFrameType { } /// Represents a window frame specification -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] pub struct WindowFrame { /// Type of window frame (ROWS or RANGE) pub frame_type: WindowFrameType, @@ -41,23 +45,69 @@ pub struct WindowFrame { } /// Represents a window specification -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] pub struct WindowSpec { /// Partition by expressions - pub partition_by: Vec, + pub partition_by: Vec>, /// Order by expressions - pub order_by: Vec, + pub order_by: Vec>, /// Whether each order by expression is ascending pub ascending: Vec, /// Window frame specification pub frame: Option, + /// Minimum number of rows required to compute a result + pub min_periods: i64, +} + +impl WindowSpec { + pub fn new() -> Self { + Self { + partition_by: Vec::new(), + order_by: Vec::new(), + ascending: Vec::new(), + frame: None, + min_periods: 1, + } + } + + pub fn with_partition_by(mut self, exprs: Vec>) -> Self { + self.partition_by = exprs; + self + } + + pub fn with_order_by(mut self, exprs: Vec>, ascending: Vec) -> Self { + assert_eq!( + exprs.len(), + ascending.len(), + "Order by expressions and ascending flags must have same length" + ); + self.order_by = exprs; + self.ascending = ascending; + self + } + + pub fn with_frame(mut self, frame: WindowFrame) -> Self { + self.frame = Some(frame); + self + } + + pub fn with_min_periods(mut self, min_periods: i64) -> Self { + self.min_periods = min_periods; + self + } +} + +impl Default for WindowSpec { + fn default() -> Self { + Self::new() + } } /// Represents a window function expression -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)] pub struct WindowFunction { /// The expression to apply the window function to - pub expr: Box, + pub expr: Arc, /// The window specification pub window_spec: WindowSpec, } @@ -65,7 +115,7 @@ pub struct WindowFunction { impl WindowFunction { pub fn new(expr: Expr, window_spec: WindowSpec) -> Self { Self { - expr: Box::new(expr), + expr: Arc::new(expr), window_spec, } } @@ -76,6 +126,27 @@ impl WindowFunction { } } +impl FunctionEvaluator for WindowFunction { + fn fn_name(&self) -> &'static str { + "window" + } + + fn to_field( + &self, + _inputs: &[crate::ExprRef], + schema: &Schema, + _expr: &FunctionExpr, + ) -> DaftResult { + // The output field has the same name and type as the input expression + self.expr.to_field(schema) + } + + fn evaluate(&self, _inputs: &[Series], _expr: &FunctionExpr) -> DaftResult { + // TODO: Implement window function evaluation + todo!("Implement window function evaluation") + } +} + #[allow(dead_code)] #[derive(Debug, Clone, PartialEq, Eq)] pub struct WindowExpr { diff --git a/src/daft-dsl/src/functions/mod.rs b/src/daft-dsl/src/functions/mod.rs index 4a4b9f0944..91e8e1a291 100644 --- a/src/daft-dsl/src/functions/mod.rs +++ b/src/daft-dsl/src/functions/mod.rs @@ -20,7 +20,7 @@ use serde::{Deserialize, Serialize}; pub use window::*; use self::{map::MapExpr, partitioning::PartitioningExpr, sketch::SketchExpr, struct_::StructExpr}; -use crate::ExprRef; +use crate::{expr::window::WindowFunction, ExprRef}; #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)] pub enum FunctionExpr { @@ -29,6 +29,7 @@ pub enum FunctionExpr { Struct(StructExpr), Python(PythonUDF), Partitioning(PartitioningExpr), + Window(WindowFunction), } pub trait FunctionEvaluator { @@ -51,6 +52,7 @@ impl FunctionExpr { Self::Struct(expr) => expr.get_evaluator(), Self::Python(expr) => expr, Self::Partitioning(expr) => expr.get_evaluator(), + Self::Window(expr) => expr, } } } diff --git a/src/daft-dsl/src/lib.rs b/src/daft-dsl/src/lib.rs index 6908bdf949..71baac5582 100644 --- a/src/daft-dsl/src/lib.rs +++ b/src/daft-dsl/src/lib.rs @@ -27,6 +27,11 @@ use pyo3::prelude::*; pub fn register_modules(parent: &Bound) -> PyResult<()> { parent.add_class::()?; + parent.add_class::()?; + parent.add_class::()?; + parent.add_class::()?; + parent.add_class::()?; + parent.add_function(wrap_pyfunction!(python::col, parent)?)?; parent.add_function(wrap_pyfunction!(python::lit, parent)?)?; parent.add_function(wrap_pyfunction!(python::list_, parent)?)?; diff --git a/src/daft-dsl/src/python.rs b/src/daft-dsl/src/python.rs index 6228dd635b..216c5ab044 100644 --- a/src/daft-dsl/src/python.rs +++ b/src/daft-dsl/src/python.rs @@ -21,7 +21,11 @@ use pyo3::{ }; use serde::{Deserialize, Serialize}; -use crate::{Expr, ExprRef, LiteralValue}; +use crate::{ + expr::window::{WindowFrameBoundary, WindowFunction}, + functions::FunctionExpr, + Expr, ExprRef, LiteralValue, +}; #[pyfunction] pub fn col(name: &str) -> PyResult { @@ -517,6 +521,18 @@ impl PyExpr { use crate::functions::partitioning::iceberg_truncate; Ok(iceberg_truncate(self.into(), w).into()) } + + pub fn over(&self, window_spec: &WindowSpec) -> PyResult { + Ok(Self { + expr: Arc::new(Expr::Function { + func: FunctionExpr::Window(WindowFunction::new( + (*self.expr).clone(), + window_spec.spec.clone(), + )), + inputs: vec![], + }), + }) + } } impl_bincode_py_state_serialization!(PyExpr); @@ -546,3 +562,134 @@ impl From<&PyExpr> for crate::ExprRef { item.expr.clone() } } + +#[pyclass(module = "daft.daft")] +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum WindowBoundary { + UnboundedPreceding(), + UnboundedFollowing(), + CurrentRow(), + Preceding(i64), + Following(i64), +} + +impl From for WindowBoundary { + fn from(value: WindowFrameBoundary) -> Self { + match value { + WindowFrameBoundary::UnboundedPreceding => Self::UnboundedPreceding(), + WindowFrameBoundary::UnboundedFollowing => Self::UnboundedFollowing(), + WindowFrameBoundary::CurrentRow => Self::CurrentRow(), + WindowFrameBoundary::Preceding(n) => Self::Preceding(n), + WindowFrameBoundary::Following(n) => Self::Following(n), + } + } +} + +#[pyclass(module = "daft.daft")] +#[derive(Debug, Clone, Serialize, Deserialize)] +pub enum WindowFrameType { + Rows(), + Range(), +} + +impl From for WindowFrameType { + fn from(value: crate::expr::window::WindowFrameType) -> Self { + match value { + crate::expr::window::WindowFrameType::Rows => Self::Rows(), + crate::expr::window::WindowFrameType::Range => Self::Range(), + } + } +} + +impl From for crate::expr::window::WindowFrameType { + fn from(value: WindowFrameType) -> Self { + match value { + WindowFrameType::Rows() => Self::Rows, + WindowFrameType::Range() => Self::Range, + } + } +} + +#[pyclass(module = "daft.daft")] +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WindowFrame { + frame: crate::expr::window::WindowFrame, +} + +#[pymethods] +impl WindowFrame { + #[new] + pub fn new(frame_type: WindowFrameType, start: WindowBoundary, end: WindowBoundary) -> Self { + Self { + frame: crate::expr::window::WindowFrame { + frame_type: frame_type.into(), + start: match start { + WindowBoundary::UnboundedPreceding() => WindowFrameBoundary::UnboundedPreceding, + WindowBoundary::UnboundedFollowing() => WindowFrameBoundary::UnboundedFollowing, + WindowBoundary::CurrentRow() => WindowFrameBoundary::CurrentRow, + WindowBoundary::Preceding(n) => WindowFrameBoundary::Preceding(n), + WindowBoundary::Following(n) => WindowFrameBoundary::Following(n), + }, + end: match end { + WindowBoundary::UnboundedPreceding() => WindowFrameBoundary::UnboundedPreceding, + WindowBoundary::UnboundedFollowing() => WindowFrameBoundary::UnboundedFollowing, + WindowBoundary::CurrentRow() => WindowFrameBoundary::CurrentRow, + WindowBoundary::Preceding(n) => WindowFrameBoundary::Preceding(n), + WindowBoundary::Following(n) => WindowFrameBoundary::Following(n), + }, + }, + } + } +} + +#[pyclass(module = "daft.daft")] +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WindowSpec { + spec: crate::expr::window::WindowSpec, +} + +impl Default for WindowSpec { + fn default() -> Self { + Self::new() + } +} + +#[pymethods] +impl WindowSpec { + #[staticmethod] + pub fn new() -> Self { + Self { + spec: crate::expr::window::WindowSpec::new(), + } + } + + pub fn with_partition_by(&self, exprs: Vec) -> PyResult { + Ok(Self { + spec: self + .spec + .clone() + .with_partition_by(exprs.into_iter().map(|e| e.into()).collect()), + }) + } + + pub fn with_order_by(&self, exprs: Vec, ascending: Vec) -> PyResult { + Ok(Self { + spec: self + .spec + .clone() + .with_order_by(exprs.into_iter().map(|e| e.into()).collect(), ascending), + }) + } + + pub fn with_frame(&self, frame: WindowFrame) -> PyResult { + Ok(Self { + spec: self.spec.clone().with_frame(frame.frame), + }) + } + + pub fn with_min_periods(&self, min_periods: i64) -> PyResult { + Ok(Self { + spec: self.spec.clone().with_min_periods(min_periods), + }) + } +}