diff --git a/velox/docs/functions/spark/aggregate.rst b/velox/docs/functions/spark/aggregate.rst index b81c05460081..b8420b4834ca 100644 --- a/velox/docs/functions/spark/aggregate.rst +++ b/velox/docs/functions/spark/aggregate.rst @@ -61,7 +61,7 @@ General Aggregate Functions .. spark:function:: collect_set(x) -> array<[same as x]> Returns an array consisting of all unique values from the input ``x`` elements excluding NULLs. - Returns empty array if input is empty or all NULL. + NaN values are treated as distinct from one another, so the result may contain more than one NaN. Returns an empty array if the input is empty or all NULL. Example:: diff --git a/velox/exec/SetAccumulator.h b/velox/exec/SetAccumulator.h index 3783ff5d843b..79084a1dfb2f 100644 --- a/velox/exec/SetAccumulator.h +++ b/velox/exec/SetAccumulator.h @@ -523,4 +523,9 @@ template using SetAccumulator = typename detail::SetAccumulatorTypeTraits::AccumulatorType; +/// Accumulator alias for floating point types that is not NaN-aware: NaNs +/// are treated as distinct values. +template +using FloatSetAccumulatorNaNUnaware = typename detail::SetAccumulator; + } // namespace facebook::velox::aggregate::prestosql diff --git a/velox/functions/sparksql/aggregates/CollectSetAggregate.cpp b/velox/functions/sparksql/aggregates/CollectSetAggregate.cpp index 000e4603451e..15ae6dc5514f 100644 --- a/velox/functions/sparksql/aggregates/CollectSetAggregate.cpp +++ b/velox/functions/sparksql/aggregates/CollectSetAggregate.cpp @@ -24,6 +24,14 @@ namespace { template using SparkSetAggAggregate = SetAggAggregate; +// NaN inputs are treated as distinct values. 
+template +using FloatSetAggAggregateNaNUnaware = SetAggAggregate< + T, + true, + false, + velox::aggregate::prestosql::FloatSetAccumulatorNaNUnaware>; + } // namespace void registerCollectSetAggAggregate( @@ -72,9 +80,11 @@ void registerCollectSetAggAggregate( "Non-decimal use of HUGEINT is not supported"); return std::make_unique>(resultType); case TypeKind::REAL: - return std::make_unique>(resultType); + return std::make_unique>( + resultType); case TypeKind::DOUBLE: - return std::make_unique>(resultType); + return std::make_unique>( + resultType); case TypeKind::TIMESTAMP: return std::make_unique>( resultType); diff --git a/velox/functions/sparksql/aggregates/tests/CollectSetAggregateTest.cpp b/velox/functions/sparksql/aggregates/tests/CollectSetAggregateTest.cpp index b61c4a47bfd8..64ece852b246 100644 --- a/velox/functions/sparksql/aggregates/tests/CollectSetAggregateTest.cpp +++ b/velox/functions/sparksql/aggregates/tests/CollectSetAggregateTest.cpp @@ -70,6 +70,7 @@ TEST_F(CollectSetAggregateTest, global) { testAggregations( {data}, {}, {"collect_set(c0)"}, {"spark_array_sort(a0)"}, {expected}); + // NaN inputs are treated as distinct values. data = makeRowVector({ makeFlatVector( {1, @@ -80,7 +81,10 @@ TEST_F(CollectSetAggregateTest, global) { expected = makeRowVector({ makeArrayVector({ - {1, std::numeric_limits::quiet_NaN()}, + {1, + std::numeric_limits::quiet_NaN(), + std::nan("1"), + std::nan("2")}, }), });