Skip to content

Commit 9c5f951

Browse files
committed
Fix NaN in collect_set
1 parent 8ed7b0b commit 9c5f951

File tree

4 files changed

+23
-4
lines changed

4 files changed

+23
-4
lines changed

velox/docs/functions/spark/aggregate.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ General Aggregate Functions
6161
.. spark:function:: collect_set(x) -> array<[same as x]>
6262
6363
Returns an array consisting of all unique values from the input ``x`` elements excluding NULLs.
64-
Returns empty array if input is empty or all NULL.
64+
NaN values are considered distinct from one another (they are not collapsed into a single entry). Returns empty array if input is empty or all NULL.
6565

6666
Example::
6767

velox/exec/SetAccumulator.h

+5
Original file line numberDiff line numberDiff line change
@@ -523,4 +523,9 @@ template <typename T>
523523
using SetAccumulator =
524524
typename detail::SetAccumulatorTypeTraits<T>::AccumulatorType;
525525

526+
/// Alias for floating point types that is not NaN-aware, so that NaNs are
527+
/// treated as distinct values.
528+
template <typename T>
529+
using FloatSetAccumulatorNaNUnaware = typename detail::SetAccumulator<T>;
530+
526531
} // namespace facebook::velox::aggregate::prestosql

velox/functions/sparksql/aggregates/CollectSetAggregate.cpp

+12-2
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,14 @@ namespace {
2424
template <typename T>
2525
using SparkSetAggAggregate = SetAggAggregate<T, true, false>;
2626

27+
// NaN inputs are treated as distinct values.
28+
template <typename T>
29+
using FloatSetAggAggregateNaNUnaware = SetAggAggregate<
30+
T,
31+
true,
32+
false,
33+
velox::aggregate::prestosql::FloatSetAccumulatorNaNUnaware<T>>;
34+
2735
} // namespace
2836

2937
void registerCollectSetAggAggregate(
@@ -72,9 +80,11 @@ void registerCollectSetAggAggregate(
7280
"Non-decimal use of HUGEINT is not supported");
7381
return std::make_unique<SparkSetAggAggregate<int128_t>>(resultType);
7482
case TypeKind::REAL:
75-
return std::make_unique<SparkSetAggAggregate<float>>(resultType);
83+
return std::make_unique<FloatSetAggAggregateNaNUnaware<float>>(
84+
resultType);
7685
case TypeKind::DOUBLE:
77-
return std::make_unique<SparkSetAggAggregate<double>>(resultType);
86+
return std::make_unique<FloatSetAggAggregateNaNUnaware<double>>(
87+
resultType);
7888
case TypeKind::TIMESTAMP:
7989
return std::make_unique<SparkSetAggAggregate<Timestamp>>(
8090
resultType);

velox/functions/sparksql/aggregates/tests/CollectSetAggregateTest.cpp

+5-1
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ TEST_F(CollectSetAggregateTest, global) {
7070
testAggregations(
7171
{data}, {}, {"collect_set(c0)"}, {"spark_array_sort(a0)"}, {expected});
7272

73+
// NaN inputs are treated as distinct values.
7374
data = makeRowVector({
7475
makeFlatVector<double>(
7576
{1,
@@ -80,7 +81,10 @@ TEST_F(CollectSetAggregateTest, global) {
8081

8182
expected = makeRowVector({
8283
makeArrayVector<double>({
83-
{1, std::numeric_limits<double>::quiet_NaN()},
84+
{1,
85+
std::numeric_limits<double>::quiet_NaN(),
86+
std::nan("1"),
87+
std::nan("2")},
8488
}),
8589
});
8690

0 commit comments

Comments
 (0)