Skip to content

Commit 80e2065

Browse files
peterenescu authored and facebook-github-bot committed
feat: Add Presto function array_top_n (facebookincubator#12105)
Summary: Adds Presto function array_top_n as a simple function in Velox. The function pushes the non-null input values into a min-heap capped at n entries (n is the second argument to the function) and then emits the retained values in descending order. Updates ArrayFunctions.h with struct ArrayTopNFunction and adds a new test file, ArrayTopNTest.cpp. Differential Revision: D68031372
1 parent bb745c1 commit 80e2065

File tree

4 files changed

+332
-0
lines changed

4 files changed

+332
-0
lines changed

velox/functions/prestosql/ArrayFunctions.h

+50
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,14 @@
2121
#include "velox/expression/PrestoCastHooks.h"
2222
#include "velox/functions/Udf.h"
2323
#include "velox/functions/lib/CheckedArithmetic.h"
24+
#include "velox/functions/lib/ComparatorUtil.h"
2425
#include "velox/functions/prestosql/json/SIMDJsonUtil.h"
2526
#include "velox/functions/prestosql/types/JsonType.h"
2627
#include "velox/type/Conversions.h"
2728
#include "velox/type/FloatingPointUtil.h"
2829

30+
#include <queue>
31+
2932
namespace facebook::velox::functions {
3033

3134
template <typename TExecCtx, bool isMax>
@@ -729,6 +732,53 @@ inline void checkIndexArrayTrim(int64_t size, int64_t arraySize) {
729732
}
730733
}
731734

735+
template <typename T>
736+
struct ArrayTopNFunction {
737+
VELOX_DEFINE_FUNCTION_TYPES(T);
738+
739+
// Definition for primitives.
740+
template <typename TReturn, typename TInput>
741+
FOLLY_ALWAYS_INLINE bool
742+
call(TReturn& result, const TInput& array, int64_t n) {
743+
// If n is invalid, exit early.
744+
if (n <= 0) {
745+
return false;
746+
}
747+
748+
// Define min-heap to store the top n elements.
749+
std::priority_queue<
750+
typename TInput::element_t,
751+
std::vector<typename TInput::element_t>,
752+
std::greater<>>
753+
minHeap;
754+
755+
// Iterate through the array and push elements to the min-heap.
756+
for (const auto& item : array) {
757+
if (item.has_value()) {
758+
minHeap.push(item.value());
759+
if (minHeap.size() > n) {
760+
minHeap.pop();
761+
}
762+
}
763+
}
764+
765+
// Reverse the min-heap to get the top n elements in descending order.
766+
std::vector<typename TInput::element_t> reversed;
767+
while (!minHeap.empty()) {
768+
reversed.push_back(minHeap.top());
769+
minHeap.pop();
770+
}
771+
std::reverse(reversed.begin(), reversed.end());
772+
773+
// Copy mutated vector to result vector up to minHeap's size items.
774+
for (const auto& item : reversed) {
775+
result.push_back(item);
776+
}
777+
778+
return true;
779+
}
780+
};
781+
732782
template <typename T>
733783
struct ArrayTrimFunction {
734784
VELOX_DEFINE_FUNCTION_TYPES(T);

velox/functions/prestosql/registration/ArrayFunctionsRegistration.cpp

+18
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,12 @@ inline void registerArrayTrimFunctions(const std::string& prefix) {
9797
{prefix + "trim_array"});
9898
}
9999

100+
template <typename T>
101+
inline void registerArrayTopNFunction(const std::string& prefix) {
102+
registerFunction<ArrayTopNFunction, Array<T>, Array<T>, int64_t>(
103+
{prefix + "array_top_n"});
104+
}
105+
100106
template <typename T>
101107
inline void registerArrayRemoveNullFunctions(const std::string& prefix) {
102108
registerFunction<ArrayRemoveNullFunction, Array<T>, Array<T>>(
@@ -241,6 +247,18 @@ void registerArrayFunctions(const std::string& prefix) {
241247
Array<Varchar>,
242248
int64_t>({prefix + "trim_array"});
243249

250+
registerArrayTopNFunction<int8_t>(prefix);
251+
registerArrayTopNFunction<int16_t>(prefix);
252+
registerArrayTopNFunction<int32_t>(prefix);
253+
registerArrayTopNFunction<int64_t>(prefix);
254+
registerArrayTopNFunction<int128_t>(prefix);
255+
registerArrayTopNFunction<float>(prefix);
256+
registerArrayTopNFunction<double>(prefix);
257+
registerArrayTopNFunction<Varchar>(prefix);
258+
registerArrayTopNFunction<Timestamp>(prefix);
259+
registerArrayTopNFunction<Date>(prefix);
260+
registerArrayTopNFunction<Varbinary>(prefix);
261+
244262
registerArrayRemoveNullFunctions<int8_t>(prefix);
245263
registerArrayRemoveNullFunctions<int16_t>(prefix);
246264
registerArrayRemoveNullFunctions<int32_t>(prefix);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,263 @@
1+
/*
2+
* Copyright (c) Facebook, Inc. and its affiliates.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#include "velox/common/base/tests/GTestUtils.h"
17+
#include "velox/functions/Macros.h"
18+
#include "velox/functions/Registerer.h"
19+
#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h"
20+
#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h"
21+
22+
#include <fmt/format.h>
23+
#include <cstdint>
24+
25+
using namespace facebook::velox;
26+
using namespace facebook::velox::test;
27+
using facebook::velox::functions::test::FunctionBaseTest;
28+
using namespace facebook::velox::functions::test;
29+
30+
namespace {
31+
32+
// Test fixture for the array_top_n tests; adds no members of its own on top
// of FunctionBaseTest.
class ArrayTopNTest : public FunctionBaseTest {};
33+
34+
// Arrays without nulls: array_top_n returns the n largest values in
// descending order, and the whole array (sorted descending) when n exceeds
// the array length.
TEST_F(ArrayTopNTest, jsonHappyPath) {
  const auto rows = makeRowVector({makeArrayVectorFromJson<int32_t>({
      "[1, 2, 3]",
      "[4, 5, 6]",
      "[7, 8, 9]",
  })});

  // Evaluates 'expression' over 'rows' and compares the outcome with
  // 'expected'.
  const auto verify = [&](const std::string& expression,
                          const VectorPtr& expected) {
    assertEqualVectors(expected, evaluate(expression, rows));
  };

  verify(
      "array_top_n(c0, 1)",
      makeArrayVectorFromJson<int32_t>({"[3]", "[6]", "[9]"}));

  verify(
      "array_top_n(c0, 2)",
      makeArrayVectorFromJson<int32_t>({"[3, 2]", "[6, 5]", "[9, 8]"}));

  // Asking for exactly the array length, or for more than it, both return
  // every element.
  const auto everything =
      makeArrayVectorFromJson<int32_t>({"[3, 2, 1]", "[6, 5, 4]", "[9, 8, 7]"});
  verify("array_top_n(c0, 3)", everything);
  verify("array_top_n(c0, 5)", everything);
}
59+
60+
// Null elements are dropped from the output regardless of element type.
TEST_F(ArrayTopNTest, nullHandler) {
  // Evaluates 'expression' against a single-column row vector holding
  // 'column' and compares the outcome with 'expected'.
  const auto verify = [&](const VectorPtr& column,
                          const std::string& expression,
                          const VectorPtr& expected) {
    assertEqualVectors(
        expected, evaluate(expression, makeRowVector({column})));
  };

  // Test fully null array vector.
  const auto onlyNulls = makeNullableArrayVector<int32_t>({
      {std::nullopt, std::nullopt},
      {std::nullopt, std::nullopt, std::nullopt},
  });
  verify(
      onlyNulls,
      "array_top_n(c0, 2)",
      makeArrayVectorFromJson<int32_t>({"[]", "[]"}));

  // Test null array vector with various different top n values.
  const auto someNulls = makeArrayVectorFromJson<int32_t>({
      "[1, null, 2, null, 3]",
      "[4, 5, null, 6, null]",
      "[null, 7, null, 8, 9]",
  });

  verify(
      someNulls,
      "array_top_n(c0, 1)",
      makeArrayVectorFromJson<int32_t>({"[3]", "[6]", "[9]"}));

  verify(
      someNulls,
      "array_top_n(c0, 2)",
      makeArrayVectorFromJson<int32_t>({"[3, 2]", "[6, 5]", "[9, 8]"}));

  // With n equal to or above the number of non-null elements, all non-null
  // elements are returned.
  const auto allNonNull =
      makeArrayVectorFromJson<int32_t>({"[3, 2, 1]", "[6, 5, 4]", "[9, 8, 7]"});
  verify(someNulls, "array_top_n(c0, 3)", allNonNull);
  verify(someNulls, "array_top_n(c0, 4)", allNonNull);

  // Test nullable array vector of bigints.
  const auto bigints = makeNullableArrayVector<int64_t>(
      {{1, 2, std::nullopt},
       {4, 5, std::nullopt, std::nullopt},
       {7, std::nullopt, std::nullopt, std::nullopt}});
  verify(
      bigints,
      "array_top_n(c0, 3)",
      makeArrayVectorFromJson<int64_t>({"[2, 1]", "[5, 4]", "[7]"}));

  // Test nullable array vector of strings.
  const auto strings = makeNullableArrayVector<std::string>({
      {"abc123", "abc", std::nullopt, "abcd"},
      {std::nullopt, "x", "xyz123", "xyzzzz"},
  });
  const auto sortedStrings = makeArrayVectorFromJson<std::string>(
      {"[\"abcd\", \"abc123\", \"abc\"]", "[\"xyzzzz\", \"xyz123\", \"x\"]"});
  verify(strings, "array_top_n(c0, 3)", sortedStrings);
  verify(strings, "array_top_n(c0, 4)", sortedStrings);
}
118+
119+
// Constant-encoded input: each row of a source array vector is wrapped in a
// constant vector and the result is checked against a constant array.
TEST_F(ArrayTopNTest, constant) {
  const vector_size_t numRows = 1'000;

  // Wraps row 'row' of 'source' in a constant vector of 'numRows' rows,
  // evaluates 'expression' on it and compares the outcome with 'expected'.
  const auto verifyRow = [&](const std::string& expression,
                             const VectorPtr& source,
                             vector_size_t row,
                             const VectorPtr& expected) {
    const auto actual = evaluate(
        expression,
        makeRowVector({BaseVector::wrapInConstant(numRows, row, source)}));
    assertEqualVectors(expected, actual);
  };

  // Top 2, including arrays with duplicate values.
  const auto shortArrays =
      makeArrayVector<int64_t>({{1, 2, 3}, {4, 5, 4, 5}, {7, 7, 7, 7}});

  verifyRow(
      "array_top_n(c0, 2)",
      shortArrays,
      0,
      makeConstantArray<int64_t>(numRows, {3, 2}));
  verifyRow(
      "array_top_n(c0, 2)",
      shortArrays,
      1,
      makeConstantArray<int64_t>(numRows, {5, 5}));
  verifyRow(
      "array_top_n(c0, 2)",
      shortArrays,
      2,
      makeConstantArray<int64_t>(numRows, {7, 7}));

  // Top 3 over longer arrays.
  const auto longArrays = makeArrayVector<int64_t>(
      {{1, 2, 3, 0, 1, 2, 2}, {4, 5, 4, 5, 5, 4}, {6, 6, 6, 6, 7, 8, 9, 10}});

  verifyRow(
      "array_top_n(c0, 3)",
      longArrays,
      0,
      makeConstantArray<int64_t>(numRows, {3, 2, 2}));
  verifyRow(
      "array_top_n(c0, 3)",
      longArrays,
      1,
      makeConstantArray<int64_t>(numRows, {5, 5, 5}));
  verifyRow(
      "array_top_n(c0, 3)",
      longArrays,
      2,
      makeConstantArray<int64_t>(numRows, {10, 9, 8}));
}
163+
164+
// Test inline (short) strings, including empty arrays, empty strings, nulls
// and duplicates.
TEST_F(ArrayTopNTest, inlineStringArrays) {
  using Str = StringView;

  const auto strings = makeNullableArrayVector<StringView>({
      {},
      {Str("")},
      {std::nullopt},
      {Str("a"), Str("b")},
      {Str("a"), std::nullopt, Str("b")},
      {Str("a"), Str("a")},
      {Str("b"), Str("a"), Str("b"), Str("a"), Str("a")},
      {std::nullopt, std::nullopt},
      {Str("b"), std::nullopt, Str("a"), Str("a"), std::nullopt, Str("b")},
  });

  // Row-by-row top 2 of 'strings'.
  const auto topTwo = makeNullableArrayVector<StringView>({
      {},
      {Str("")},
      {},
      {Str("b"), Str("a")},
      {Str("b"), Str("a")},
      {Str("a"), Str("a")},
      {Str("b"), Str("b")},
      {},
      {Str("b"), Str("b")},
  });

  assertEqualVectors(
      topTwo,
      evaluate<ArrayVector>("array_top_n(C0, 2)", makeRowVector({strings})));
}
196+
197+
// Test non-inline (> 12 character length) strings, with nulls and
// duplicates.
TEST_F(ArrayTopNTest, stringArrays) {
  using Str = StringView;

  const auto sentences = makeNullableArrayVector<StringView>({
      {Str("red shiny car ahead"), Str("blue clear sky above")},
      {Str("blue clear sky above"),
       Str("yellow rose flowers"),
       std::nullopt,
       Str("blue clear sky above"),
       Str("orange beautiful sunset")},
      {
          Str("red shiny car ahead"),
          std::nullopt,
          Str("purple is an elegant color"),
          Str("red shiny car ahead"),
          Str("green plants make us happy"),
          Str("purple is an elegant color"),
          std::nullopt,
          Str("purple is an elegant color"),
      },
  });

  // Row-by-row top 3 of 'sentences'.
  const auto topThree = makeNullableArrayVector<StringView>({
      {Str("red shiny car ahead"), Str("blue clear sky above")},
      {Str("yellow rose flowers"),
       Str("orange beautiful sunset"),
       Str("blue clear sky above")},
      {Str("red shiny car ahead"),
       Str("red shiny car ahead"),
       Str("purple is an elegant color")},
  });

  assertEqualVectors(
      topThree,
      evaluate<ArrayVector>("array_top_n(C0, 3)", makeRowVector({sentences})));
}
234+
235+
// array_top_n evaluated on interleaved subsets of rows: even rows take the
// top 2 of c1, odd rows take the top 2 of c2.
TEST_F(ArrayTopNTest, nonContiguousRows) {
  const auto rowNumbers =
      makeFlatVector<int64_t>(4, [](auto row) { return row; });

  const auto evenBranchArrays = makeArrayVector<int64_t>({
      {1, 1, 2, 3, 3},
      {1, 1, 2, 3, 4, 4},
      {1, 1, 2, 3, 4, 5, 5},
      {1, 1, 2, 3, 3, 4, 5, 6, 6},
  });

  const auto oddBranchArrays = makeArrayVector<int64_t>({
      {0, 0, 1, 1, 2, 3, 3},
      {0, 0, 1, 1, 2, 3, 4, 4},
      {0, 0, 1, 1, 2, 3, 4, 5, 5},
      {0, 0, 1, 1, 2, 3, 4, 5, 6, 6},
  });

  const auto topTwoPerRow = makeArrayVector<int64_t>({
      {3, 3},
      {4, 4},
      {5, 5},
      {6, 6},
  });

  // Column order fixes the names c0, c1 and c2 used in the expression.
  const auto rows =
      makeRowVector({rowNumbers, evenBranchArrays, oddBranchArrays});

  assertEqualVectors(
      topTwoPerRow,
      evaluate<ArrayVector>(
          "if(c0 % 2 = 0, array_top_n(c1, 2), array_top_n(c2, 2))", rows));
}
263+
} // namespace

velox/functions/prestosql/tests/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ add_executable(
4444
ArrayRemoveTest.cpp
4545
ArrayShuffleTest.cpp
4646
ArraySortTest.cpp
47+
ArrayTopNTest.cpp
4748
ArraysOverlapTest.cpp
4849
ArraySumTest.cpp
4950
ArrayTrimTest.cpp

0 commit comments

Comments
 (0)