diff --git a/crates/polars-plan/src/dsl/expr.rs b/crates/polars-plan/src/dsl/expr.rs index 32fa45528d3a..c28b1d5d329f 100644 --- a/crates/polars-plan/src/dsl/expr.rs +++ b/crates/polars-plan/src/dsl/expr.rs @@ -369,7 +369,9 @@ impl Expr { expr_arena: &mut Arena, ) -> PolarsResult { let root = to_aexpr(self.clone(), expr_arena)?; - expr_arena.get(root).to_field(schema, ctxt, expr_arena) + expr_arena + .get(root) + .to_field_and_validate(schema, ctxt, expr_arena) } } diff --git a/crates/polars-plan/src/plans/aexpr/schema.rs b/crates/polars-plan/src/plans/aexpr/schema.rs index 6547c391eaae..3b6226cc8ade 100644 --- a/crates/polars-plan/src/plans/aexpr/schema.rs +++ b/crates/polars-plan/src/plans/aexpr/schema.rs @@ -15,6 +15,13 @@ fn float_type(field: &mut Field) { } } +fn validate_expr(node: Node, arena: &Arena, schema: &Schema) -> PolarsResult<()> { + arena + .get(node) + .to_field_impl(schema, Context::Default, arena, &mut false, true) + .map(|_| ()) +} + impl AExpr { pub fn to_dtype( &self, @@ -36,7 +43,27 @@ impl AExpr { // in an aggregation context, so functions that return scalars should explicitly set this // to false in `to_field_impl`. let mut agg_list = matches!(ctx, Context::Aggregation); - let mut field = self.to_field_impl(schema, ctx, arena, &mut agg_list)?; + let mut field = self.to_field_impl(schema, ctx, arena, &mut agg_list, false)?; + + if agg_list { + field.coerce(field.dtype().clone().implode()); + } + + Ok(field) + } + + /// Get Field result of the expression. The schema is the input data. + pub fn to_field_and_validate( + &self, + schema: &Schema, + ctx: Context, + arena: &Arena, + ) -> PolarsResult { + // Indicates whether we should auto-implode the result. This is initialized to true if we are + // in an aggregation context, so functions that return scalars should explicitly set this + // to false in `to_field_impl`. + let mut agg_list = matches!(ctx, Context::Aggregation); + let mut field = self.to_field_impl(schema, ctx, arena, &mut agg_list, true)?; if agg_list { field.coerce(field.dtype().clone().implode()); @@ -56,6 +83,8 @@ impl AExpr { ctx: Context, arena: &Arena, agg_list: &mut bool, + // Traverse all expressions to validate they are in the schema. + validate: bool, ) -> PolarsResult { use AExpr::*; use DataType::*; @@ -65,7 +94,10 @@ impl AExpr { Ok(Field::new(PlSmallStr::from_static(LEN), IDX_DTYPE)) }, Window { - function, options, .. + function, + options, + partition_by, + order_by, } => { if let WindowType::Over(WindowMapping::Join) = options { // expr.over(..), defaults to agg-list unless explicitly unset @@ -73,8 +105,17 @@ impl AExpr { *agg_list = true; } + if validate { + for node in partition_by { + validate_expr(*node, arena, schema)?; + } + if let Some((node, _)) = order_by { + validate_expr(*node, arena, schema)?; + } + } + let e = arena.get(*function); - e.to_field_impl(schema, ctx, arena, agg_list) + e.to_field_impl(schema, ctx, arena, agg_list, validate) }, Explode(expr) => { // `Explode` is a "flatten" operation, which is not the same as returning a scalar. @@ -82,7 +123,7 @@ impl AExpr { // the `agg_list` state here. let field = arena .get(*expr) - .to_field_impl(schema, ctx, arena, &mut false)?; + .to_field_impl(schema, ctx, arena, &mut false, validate)?; let field = match field.dtype() { List(inner) => Field::new(field.name().clone(), *inner.clone()), @@ -97,7 +138,7 @@ impl AExpr { name.clone(), arena .get(*expr) - .to_field_impl(schema, ctx, arena, agg_list)? + .to_field_impl(schema, ctx, arena, agg_list, validate)? .dtype, )), Column(name) => schema @@ -128,40 +169,55 @@ impl AExpr { let out_name = { out_field = arena .get(*left) - .to_field_impl(schema, ctx, arena, agg_list)?; + .to_field_impl(schema, ctx, arena, agg_list, validate)?; out_field.name() }; Field::new(out_name.clone(), Boolean) }, Operator::TrueDivide => { - return get_truediv_field(*left, *right, arena, ctx, schema, agg_list) + return get_truediv_field( + *left, *right, arena, ctx, schema, agg_list, validate, + ) }, _ => { return get_arithmetic_field( - *left, *right, arena, *op, ctx, schema, agg_list, + *left, *right, arena, *op, ctx, schema, agg_list, validate, ) }, }; Ok(field) }, - Sort { expr, .. } => arena.get(*expr).to_field_impl(schema, ctx, arena, agg_list), + Sort { expr, .. } => arena + .get(*expr) + .to_field_impl(schema, ctx, arena, agg_list, validate), Gather { expr, + idx, returns_scalar, .. } => { if *returns_scalar { *agg_list = false; } + if validate { + validate_expr(*idx, arena, schema)? + } arena .get(*expr) - .to_field_impl(schema, ctx, arena, &mut false) + .to_field_impl(schema, ctx, arena, &mut false, validate) + }, + SortBy { expr, .. } => arena + .get(*expr) + .to_field_impl(schema, ctx, arena, agg_list, validate), + Filter { input, by } => { + if validate { + validate_expr(*by, arena, schema)? + } + arena + .get(*input) + .to_field_impl(schema, ctx, arena, agg_list, validate) }, - SortBy { expr, .. } => arena.get(*expr).to_field_impl(schema, ctx, arena, agg_list), - Filter { input, .. } => arena - .get(*input) - .to_field_impl(schema, ctx, arena, agg_list), Agg(agg) => { use IRAggExpr::*; match agg { @@ -172,13 +228,13 @@ impl AExpr { *agg_list = false; arena .get(*expr) - .to_field_impl(schema, ctx, arena, &mut false) + .to_field_impl(schema, ctx, arena, &mut false, validate) }, Sum(expr) => { *agg_list = false; let mut field = arena .get(*expr) - .to_field_impl(schema, ctx, arena, &mut false)?; + .to_field_impl(schema, ctx, arena, &mut false, validate)?; let dt = match field.dtype() { Boolean => Some(IDX_DTYPE), UInt8 | Int8 | Int16 | UInt16 => Some(Int64), @@ -193,7 +249,7 @@ impl AExpr { *agg_list = false; let mut field = arena .get(*expr) - .to_field_impl(schema, ctx, arena, &mut false)?; + .to_field_impl(schema, ctx, arena, &mut false, validate)?; match field.dtype { Date => field.coerce(Datetime(TimeUnit::Milliseconds, None)), _ => float_type(&mut field), @@ -204,7 +260,7 @@ impl AExpr { *agg_list = false; let mut field = arena .get(*expr) - .to_field_impl(schema, ctx, arena, &mut false)?; + .to_field_impl(schema, ctx, arena, &mut false, validate)?; match field.dtype { Date => field.coerce(Datetime(TimeUnit::Milliseconds, None)), _ => float_type(&mut field), @@ -214,7 +270,7 @@ impl AExpr { Implode(expr) => { let mut field = arena .get(*expr) - .to_field_impl(schema, ctx, arena, &mut false)?; + .to_field_impl(schema, ctx, arena, &mut false, validate)?; field.coerce(DataType::List(field.dtype().clone().into())); Ok(field) }, @@ -222,7 +278,7 @@ impl AExpr { *agg_list = false; let mut field = arena .get(*expr) - .to_field_impl(schema, ctx, arena, &mut false)?; + .to_field_impl(schema, ctx, arena, &mut false, validate)?; float_type(&mut field); Ok(field) }, @@ -230,7 +286,7 @@ impl AExpr { *agg_list = false; let mut field = arena .get(*expr) - .to_field_impl(schema, ctx, arena, &mut false)?; + .to_field_impl(schema, ctx, arena, &mut false, validate)?; float_type(&mut field); Ok(field) }, @@ -238,7 +294,7 @@ impl AExpr { *agg_list = false; let mut field = arena .get(*expr) - .to_field_impl(schema, ctx, arena, &mut false)?; + .to_field_impl(schema, ctx, arena, &mut false, validate)?; field.coerce(IDX_DTYPE); Ok(field) }, @@ -246,7 +302,7 @@ impl AExpr { *agg_list = false; let mut field = arena .get(*expr) - .to_field_impl(schema, ctx, arena, &mut false)?; + .to_field_impl(schema, ctx, arena, &mut false, validate)?; field.coerce(IDX_DTYPE); Ok(field) }, @@ -254,7 +310,7 @@ impl AExpr { *agg_list = true; let mut field = arena .get(*expr) - .to_field_impl(schema, ctx, arena, &mut false)?; + .to_field_impl(schema, ctx, arena, &mut false, validate)?; field.coerce(List(IDX_DTYPE.into())); Ok(field) }, @@ -262,7 +318,7 @@ impl AExpr { *agg_list = false; let mut field = arena .get(*expr) - .to_field_impl(schema, ctx, arena, &mut false)?; + .to_field_impl(schema, ctx, arena, &mut false, validate)?; float_type(&mut field); Ok(field) }, @@ -271,7 +327,7 @@ impl AExpr { *agg_list = false; let field = arena .get(*expr) - .to_field_impl(schema, ctx, arena, &mut false)?; + .to_field_impl(schema, ctx, arena, &mut false, validate)?; // @Q? Do we need to coerce here? Ok(field) }, @@ -280,7 +336,7 @@ impl AExpr { Cast { expr, dtype, .. } => { let field = arena .get(*expr) - .to_field_impl(schema, ctx, arena, agg_list)?; + .to_field_impl(schema, ctx, arena, agg_list, validate)?; Ok(Field::new(field.name().clone(), dtype.clone())) }, Ternary { truthy, falsy, .. } => { @@ -291,14 +347,20 @@ impl AExpr { // left: col(foo): list nesting: 1 // right; col(foo).first(): T nesting: 0 // col(foo) + col(foo).first() will have nesting 1 as we still maintain the groups list. - let mut truthy = - arena - .get(*truthy) - .to_field_impl(schema, ctx, arena, &mut agg_list_truthy)?; - let falsy = - arena - .get(*falsy) - .to_field_impl(schema, ctx, arena, &mut agg_list_falsy)?; + let mut truthy = arena.get(*truthy).to_field_impl( + schema, + ctx, + arena, + &mut agg_list_truthy, + validate, + )?; + let falsy = arena.get(*falsy).to_field_impl( + schema, + ctx, + arena, + &mut agg_list_falsy, + validate, + )?; let st = if let DataType::Null = *truthy.dtype() { falsy.dtype().clone() @@ -317,7 +379,7 @@ impl AExpr { options, .. } => { - let fields = func_args_to_fields(input, ctx, schema, arena, agg_list)?; + let fields = func_args_to_fields(input, ctx, schema, arena, agg_list, validate)?; polars_ensure!(!fields.is_empty(), ComputeError: "expression: '{}' didn't get any inputs", options.fmt_str); let out = output_type.get_field(schema, ctx, &fields)?; @@ -334,7 +396,7 @@ impl AExpr { input, options, } => { - let fields = func_args_to_fields(input, ctx, schema, arena, agg_list)?; + let fields = func_args_to_fields(input, ctx, schema, arena, agg_list, validate)?; polars_ensure!(!fields.is_empty(), ComputeError: "expression: '{}' didn't get any inputs", function); let out = function.get_field(schema, ctx, &fields)?; @@ -346,9 +408,20 @@ impl AExpr { Ok(out) }, - Slice { input, .. } => arena - .get(*input) - .to_field_impl(schema, ctx, arena, agg_list), + Slice { + input, + offset, + length, + } => { + if validate { + validate_expr(*offset, arena, schema)?; + validate_expr(*length, arena, schema)?; + } + + arena + .get(*input) + .to_field_impl(schema, ctx, arena, agg_list, validate) + }, } } } @@ -359,6 +432,7 @@ fn func_args_to_fields( schema: &Schema, arena: &Arena, agg_list: &mut bool, + validate: bool, ) -> PolarsResult> { input .iter() @@ -379,6 +453,7 @@ fn func_args_to_fields( } else { tmp }, + validate, ) .map(|mut field| { field.name = e.output_name().clone(); @@ -388,6 +463,7 @@ fn func_args_to_fields( .collect() } +#[allow(clippy::too_many_arguments)] fn get_arithmetic_field( left: Node, right: Node, @@ -396,6 +472,7 @@ fn get_arithmetic_field( ctx: Context, schema: &Schema, agg_list: &mut bool, + validate: bool, ) -> PolarsResult { use DataType::*; let left_ae = arena.get(left); @@ -409,11 +486,13 @@ fn get_arithmetic_field( // leading to quadratic behavior. # 4736 // // further right_type is only determined when needed. - let mut left_field = left_ae.to_field_impl(schema, ctx, arena, agg_list)?; + let mut left_field = left_ae.to_field_impl(schema, ctx, arena, agg_list, validate)?; let super_type = match op { Operator::Minus => { - let right_type = right_ae.to_field_impl(schema, ctx, arena, agg_list)?.dtype; + let right_type = right_ae + .to_field_impl(schema, ctx, arena, agg_list, validate)? + .dtype; match (&left_field.dtype, &right_type) { #[cfg(feature = "dtype-struct")] (Struct(_), Struct(_)) => { @@ -468,7 +547,9 @@ fn get_arithmetic_field( } }, Operator::Plus => { - let right_type = right_ae.to_field_impl(schema, ctx, arena, agg_list)?.dtype; + let right_type = right_ae + .to_field_impl(schema, ctx, arena, agg_list, validate)? + .dtype; match (&left_field.dtype, &right_type) { (Duration(_), Datetime(_, _)) | (Datetime(_, _), Duration(_)) @@ -510,7 +591,9 @@ fn get_arithmetic_field( } }, _ => { - let right_type = right_ae.to_field_impl(schema, ctx, arena, agg_list)?.dtype; + let right_type = right_ae + .to_field_impl(schema, ctx, arena, agg_list, validate)? + .dtype; match (&left_field.dtype, &right_type) { #[cfg(feature = "dtype-struct")] @@ -597,13 +680,14 @@ fn get_truediv_field( ctx: Context, schema: &Schema, agg_list: &mut bool, + validate: bool, ) -> PolarsResult { let mut left_field = arena .get(left) - .to_field_impl(schema, ctx, arena, agg_list)?; + .to_field_impl(schema, ctx, arena, agg_list, validate)?; let right_field = arena .get(right) - .to_field_impl(schema, ctx, arena, agg_list)?; + .to_field_impl(schema, ctx, arena, agg_list, validate)?; use DataType::*; // TODO: Re-investigate this. A lot of "_" is being used on the RHS match because this code diff --git a/crates/polars-plan/src/utils.rs b/crates/polars-plan/src/utils.rs index bf18cc4119d2..4bdf483474c3 100644 --- a/crates/polars-plan/src/utils.rs +++ b/crates/polars-plan/src/utils.rs @@ -357,7 +357,10 @@ pub(crate) fn expr_irs_to_schema, K: AsRef>( expr.into_iter() .map(|e| { let e = e.as_ref(); - let mut field = arena.get(e.node()).to_field(schema, ctxt, arena).unwrap(); + let mut field = arena + .get(e.node()) + .to_field(schema, ctxt, arena) + .expect("should be resolved"); if let Some(name) = e.get_alias() { field.name = name.clone() diff --git a/py-polars/tests/unit/test_schema.py b/py-polars/tests/unit/test_schema.py index 43e8840458d3..640343de2a77 100644 --- a/py-polars/tests/unit/test_schema.py +++ b/py-polars/tests/unit/test_schema.py @@ -296,3 +296,14 @@ def test_lf_explode_schema() -> None: q = lf.select(pl.col("x").list.explode()) assert q.collect_schema() == {"x": pl.Int64} + + +def test_raise_subnodes_18787() -> None: + df = pl.DataFrame({"a": [1], "b": [2]}) + + with pytest.raises(pl.exceptions.ColumnNotFoundError): + ( + df.select(pl.struct(pl.all())).select( + pl.first().struct.field("a", "b").filter(pl.col("foo") == 1) + ) + )