diff --git a/velox/core/QueryConfig.h b/velox/core/QueryConfig.h index 2f6c3e55cfd6..ed565bb40f4f 100644 --- a/velox/core/QueryConfig.h +++ b/velox/core/QueryConfig.h @@ -333,6 +333,12 @@ class QueryConfig { static constexpr const char* kSparkLegacyDateFormatter = "spark.legacy_date_formatter"; + /// If true, statistical aggregation function includes skewness, kurtosis, + /// will return NaN instead of NULL when dividing by zero during expression + /// evaluation. + static constexpr const char* kSparkLegacyStatisticalAggregate = + "spark.legacy_statistical_aggregate"; + /// The number of local parallel table writer operators per task. static constexpr const char* kTaskWriterCount = "task_writer_count"; @@ -831,6 +837,10 @@ class QueryConfig { return get(kSparkLegacyDateFormatter, false); } + bool sparkLegacyStatisticalAggregate() const { + return get(kSparkLegacyStatisticalAggregate, false); + } + bool exprTrackCpuUsage() const { return get(kExprTrackCpuUsage, false); } diff --git a/velox/docs/configs.rst b/velox/docs/configs.rst index 3f89eded8389..7f0a72eb41f1 100644 --- a/velox/docs/configs.rst +++ b/velox/docs/configs.rst @@ -887,6 +887,13 @@ Spark-specific Configuration Joda date formatter performs strict checking of its input and uses different pattern string. For example, the 2015-07-22 10:00:00 timestamp cannot be parsed if pattern is yyyy-MM-dd because the parser does not consume whole input. Another example is that the 'W' pattern, which means week in month, is not supported. For more differences, see :issue:`10354`. + * - spark.legacy_statistical_aggregate + - bool + - false + - If true, statistical aggregation function includes skewness, kurtosis will return std::numeric_limits::quiet_NaN() + - instead of NULL when dividing by zero during expression evaluation. It is worth noting that Spark statistical aggregation functions + - including stddev, stddev_samp, variance, var_samp, covar_samp, corr should also respect this configuration, + - although they have not been supported yet. Tracing -------- diff --git a/velox/functions/sparksql/aggregates/CentralMomentsAggregate.cpp b/velox/functions/sparksql/aggregates/CentralMomentsAggregate.cpp index 90a0ade6f0e7..7afe2bf71395 100644 --- a/velox/functions/sparksql/aggregates/CentralMomentsAggregate.cpp +++ b/velox/functions/sparksql/aggregates/CentralMomentsAggregate.cpp @@ -15,28 +15,57 @@ */ #include "velox/functions/sparksql/aggregates/CentralMomentsAggregate.h" +#include #include "velox/functions/lib/aggregates/CentralMomentsAggregatesBase.h" namespace facebook::velox::functions::aggregate::sparksql { namespace { +// Calculate the skewness value from m2, count and m3. +// +// @tparam nullOnDivideByZero If true, return NULL instead of NaN when dividing +// by zero during the calculating. +template struct SkewnessResultAccessor { static bool hasResult(const CentralMomentsAccumulator& accumulator) { - return accumulator.count() >= 1 && accumulator.m2() != 0; + if constexpr (nullOnDivideByZero) { + return accumulator.count() >= 1 && accumulator.m2() != 0; + } + return accumulator.count() >= 1; } static double result(const CentralMomentsAccumulator& accumulator) { + if (accumulator.m2() == 0) { + VELOX_USER_CHECK( + !nullOnDivideByZero, + "If NaN is returned when m2 is 0, nullOnDivideByZero must be false"); + return std::numeric_limits::quiet_NaN(); + } return std::sqrt(accumulator.count()) * accumulator.m3() / std::pow(accumulator.m2(), 1.5); } }; +// Calculate the kurtosis value from m2, count and m4. +// +// @tparam nullOnDivideByZero If true, return NULL instead of NaN when dividing +// by zero during the calculating. +template struct KurtosisResultAccessor { static bool hasResult(const CentralMomentsAccumulator& accumulator) { - return accumulator.count() >= 1 && accumulator.m2() != 0; + if constexpr (nullOnDivideByZero) { + return accumulator.count() >= 1 && accumulator.m2() != 0; + } + return accumulator.count() >= 1; } static double result(const CentralMomentsAccumulator& accumulator) { + if (accumulator.m2() == 0) { + VELOX_USER_CHECK( + !nullOnDivideByZero, + "If NaN is returned when m2 is 0, nullOnDivideByZero must be false"); + return std::numeric_limits::quiet_NaN(); + } double count = accumulator.count(); double m2 = accumulator.m2(); double m4 = accumulator.m4(); @@ -44,22 +73,23 @@ struct KurtosisResultAccessor { } }; -template -exec::AggregateRegistrationResult registerCentralMoments( +std::vector> getSignatures() { + std::vector> signatures; + signatures.push_back( + exec::AggregateFunctionSignatureBuilder() + .returnType("double") + .intermediateType(CentralMomentsIntermediateResult::type()) + .argumentType("double") + .build()); + return signatures; +} + +exec::AggregateRegistrationResult registerSkewness( const std::string& name, bool withCompanionFunctions, bool overwrite) { - std::vector> signatures; - std::vector inputTypes = { - "smallint", "integer", "bigint", "real", "double"}; - for (const auto& inputType : inputTypes) { - signatures.push_back( - exec::AggregateFunctionSignatureBuilder() - .returnType("double") - .intermediateType(CentralMomentsIntermediateResult::type()) - .argumentType(inputType) - .build()); - } + std::vector> signatures = + getSignatures(); return exec::registerAggregateFunction( name, @@ -68,47 +98,118 @@ exec::AggregateRegistrationResult registerCentralMoments( core::AggregationNode::Step step, const std::vector& argTypes, const TypePtr& resultType, - const core::QueryConfig& /*config*/) - -> std::unique_ptr { - VELOX_CHECK_LE( - argTypes.size(), 1, "{} takes at most one argument", name); + const core::QueryConfig& config) -> std::unique_ptr { + VELOX_CHECK_EQ(argTypes.size(), 1, "{} takes only one argument", name); const auto& inputType = argTypes[0]; - if (exec::isRawInput(step)) { - switch (inputType->kind()) { - case TypeKind::SMALLINT: - return std::make_unique< - CentralMomentsAggregatesBase>( - resultType); - case TypeKind::INTEGER: - return std::make_unique< - CentralMomentsAggregatesBase>( - resultType); - case TypeKind::BIGINT: - return std::make_unique< - CentralMomentsAggregatesBase>( - resultType); - case TypeKind::DOUBLE: - return std::make_unique< - CentralMomentsAggregatesBase>( - resultType); - case TypeKind::REAL: - return std::make_unique< - CentralMomentsAggregatesBase>( - resultType); - default: + if (config.sparkLegacyStatisticalAggregate()) { + if (exec::isRawInput(step)) { + if (inputType->kind() == TypeKind::DOUBLE) { + return std::make_unique>>(resultType); + } else { VELOX_UNSUPPORTED( "Unsupported input type: {}. " - "Expected SMALLINT, INTEGER, BIGINT, DOUBLE or REAL.", + "Expected DOUBLE.", inputType->toString()); + } + } else { + checkAccumulatorRowType( + inputType, + "Input type for final aggregation must be " + "(count:bigint, m1:double, m2:double, m3:double, m4:double) struct"); + return std::make_unique>>(resultType); + } + } else { + if (exec::isRawInput(step)) { + if (inputType->kind() == TypeKind::DOUBLE) { + return std::make_unique>>(resultType); + } else { + VELOX_UNSUPPORTED( + "Unsupported input type: {}. " + "Expected DOUBLE.", + inputType->toString()); + } + } else { + checkAccumulatorRowType( + inputType, + "Input type for final aggregation must be " + "(count:bigint, m1:double, m2:double, m3:double, m4:double) struct"); + return std::make_unique>>(resultType); + } + } + }, + withCompanionFunctions, + overwrite); +} + +exec::AggregateRegistrationResult registerKurtosis( + const std::string& name, + bool withCompanionFunctions, + bool overwrite) { + std::vector> signatures = + getSignatures(); + + return exec::registerAggregateFunction( + name, + std::move(signatures), + [name]( + core::AggregationNode::Step step, + const std::vector& argTypes, + const TypePtr& resultType, + const core::QueryConfig& config) -> std::unique_ptr { + VELOX_CHECK_EQ(argTypes.size(), 1, "{} takes only one argument", name); + const auto& inputType = argTypes[0]; + if (config.sparkLegacyStatisticalAggregate()) { + if (exec::isRawInput(step)) { + switch (inputType->kind()) { + case TypeKind::DOUBLE: + return std::make_unique>>(resultType); + default: + VELOX_UNSUPPORTED( + "Unsupported input type: {}. " + "Expected SMALLINT, INTEGER, BIGINT, DOUBLE or REAL.", + inputType->toString()); + } + } else { + checkAccumulatorRowType( + inputType, + "Input type for final aggregation must be " + "(count:bigint, m1:double, m2:double, m3:double, m4:double) struct"); + return std::make_unique>>(resultType); } } else { - checkAccumulatorRowType( - inputType, - "Input type for final aggregation must be " - "(count:bigint, m1:double, m2:double, m3:double, m4:double) struct"); - return std::make_unique>(resultType); + if (exec::isRawInput(step)) { + switch (inputType->kind()) { + case TypeKind::DOUBLE: + return std::make_unique>>(resultType); + default: + VELOX_UNSUPPORTED( + "Unsupported input type: {}. " + "Expected SMALLINT, INTEGER, BIGINT, DOUBLE or REAL.", + inputType->toString()); + } + } else { + checkAccumulatorRowType( + inputType, + "Input type for final aggregation must be " + "(count:bigint, m1:double, m2:double, m3:double, m4:double) struct"); + return std::make_unique>>(resultType); + } } }, withCompanionFunctions, @@ -120,10 +221,8 @@ void registerCentralMomentsAggregate( const std::string& prefix, bool withCompanionFunctions, bool overwrite) { - registerCentralMoments( - prefix + "skewness", withCompanionFunctions, overwrite); - registerCentralMoments( - prefix + "kurtosis", withCompanionFunctions, overwrite); + registerSkewness(prefix + "skewness", withCompanionFunctions, overwrite); + registerKurtosis(prefix + "kurtosis", withCompanionFunctions, overwrite); } } // namespace facebook::velox::functions::aggregate::sparksql diff --git a/velox/functions/sparksql/aggregates/tests/CentralMomentsAggregationTest.cpp b/velox/functions/sparksql/aggregates/tests/CentralMomentsAggregationTest.cpp index 9f5dd9b3efd8..eef28d501c7c 100644 --- a/velox/functions/sparksql/aggregates/tests/CentralMomentsAggregationTest.cpp +++ b/velox/functions/sparksql/aggregates/tests/CentralMomentsAggregationTest.cpp @@ -40,44 +40,76 @@ class CentralMomentsAggregationTest : public AggregationTestBase { builder.singleAggregation({}, {fmt::format("spark_{}(c0)", agg)}); AssertQueryBuilder(builder.planNode()).assertResults({expected}); } + + void testLegacyCenteralMomentsAggResult( + const std::string& agg, + const RowVectorPtr& input, + const RowVectorPtr& expected) { + PlanBuilder builder(pool()); + builder.values({input}); + builder.singleAggregation({}, {fmt::format("spark_{}(c0)", agg)}); + AssertQueryBuilder(builder.planNode()) + .config("spark.legacy_statistical_aggregate", "true") + .assertResults({expected}); + } }; TEST_F(CentralMomentsAggregationTest, skewnessHasResult) { auto agg = "skewness"; - auto input = makeRowVector({makeFlatVector({1, 2})}); + auto input = makeRowVector({makeFlatVector({1, 2})}); // Even when the count is 2, Spark still produces output. auto expected = makeRowVector({makeFlatVector(std::vector{0.0})}); testCenteralMomentsAggResult(agg, input, expected); - input = makeRowVector({makeFlatVector({1, 1})}); + input = makeRowVector({makeFlatVector({1, 1})}); expected = makeRowVector({makeNullableFlatVector( std::vector>{std::nullopt})}); testCenteralMomentsAggResult(agg, input, expected); + + // Output NULL when m2 equals 0. + input = makeRowVector({makeFlatVector({1, 1})}); + expected = makeRowVector({makeNullableFlatVector( + std::vector>{std::nullopt})}); + testCenteralMomentsAggResult(agg, input, expected); + + // Output NaN when m2 equals 0 for legacy aggregate. + input = makeRowVector({makeFlatVector({1, 1})}); + expected = makeRowVector( + {makeNullableFlatVector(std::vector>{ + std::numeric_limits::quiet_NaN()})}); + testLegacyCenteralMomentsAggResult(agg, input, expected); } TEST_F(CentralMomentsAggregationTest, pearsonKurtosis) { auto agg = "kurtosis"; - auto input = makeRowVector({makeFlatVector({1, 10, 100, 10, 1})}); + auto input = makeRowVector({makeFlatVector({1, 10, 100, 10, 1})}); auto expected = makeRowVector( {makeFlatVector(std::vector{0.19432323191699075})}); testCenteralMomentsAggResult(agg, input, expected); - input = makeRowVector({makeFlatVector({-10, -20, 100, 1000})}); + input = makeRowVector({makeFlatVector({-10, -20, 100, 1000})}); expected = makeRowVector( {makeFlatVector(std::vector{-0.7014368047529627})}); testCenteralMomentsAggResult(agg, input, expected); // Even when the count is 2, Spark still produces non-null result. - input = makeRowVector({makeFlatVector({1, 2})}); + input = makeRowVector({makeFlatVector({1, 2})}); expected = makeRowVector({makeFlatVector(std::vector{-2.0})}); testCenteralMomentsAggResult(agg, input, expected); // Output NULL when m2 equals 0. - input = makeRowVector({makeFlatVector({1, 1})}); + input = makeRowVector({makeFlatVector({1, 1})}); expected = makeRowVector({makeNullableFlatVector( std::vector>{std::nullopt})}); testCenteralMomentsAggResult(agg, input, expected); + + // Output NaN when m2 equals 0 for legacy aggregate. + input = makeRowVector({makeFlatVector({1, 1})}); + expected = makeRowVector( + {makeNullableFlatVector(std::vector>{ + std::numeric_limits::quiet_NaN()})}); + testLegacyCenteralMomentsAggResult(agg, input, expected); } } // namespace