Skip to content

Commit e533236

Browse files
peterenescufacebook-github-bot
authored andcommitted
feat(array): Add Presto function array_top_n (facebookincubator#12105)
Summary: Adds Presto function array_top_n as a simple function in Velox. Function uses a temporary vector to store inputted values and heap sorts them up to k values (second input to function). Updates ArrayFunction.h with struct ArrayTopNFunction and adds new tester function ArrayTopNTest.cpp Differential Revision: D68031372
1 parent 5fa90be commit e533236

File tree

4 files changed

+620
-1
lines changed

4 files changed

+620
-1
lines changed

velox/functions/prestosql/ArrayFunctions.h

+153-1
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,14 @@
2121
#include "velox/expression/PrestoCastHooks.h"
2222
#include "velox/functions/Udf.h"
2323
#include "velox/functions/lib/CheckedArithmetic.h"
24+
#include "velox/functions/lib/ComparatorUtil.h"
2425
#include "velox/functions/prestosql/json/SIMDJsonUtil.h"
2526
#include "velox/functions/prestosql/types/JsonType.h"
2627
#include "velox/type/Conversions.h"
2728
#include "velox/type/FloatingPointUtil.h"
2829

30+
#include <queue>
31+
2932
namespace facebook::velox::functions {
3033

3134
template <typename TExecCtx, bool isMax>
@@ -729,13 +732,162 @@ inline void checkIndexArrayTrim(int64_t size, int64_t arraySize) {
729732
}
730733
}
731734

735+
/// This class implements the array_top_n function.
736+
///
737+
/// DEFINITION:
738+
/// array_top_n(array(T), int) -> array(T)
739+
/// Returns the top n elements of the array in descending order.
740+
template <typename T>
741+
struct ArrayTopNFunction {
742+
VELOX_DEFINE_FUNCTION_TYPES(T);
743+
744+
// Definition for primitives.
745+
template <typename TReturn, typename TInput>
746+
FOLLY_ALWAYS_INLINE void
747+
call(TReturn& result, const TInput& array, int32_t n) {
748+
VELOX_CHECK(
749+
n >= 0, fmt::format("Parameter n: {} to ARRAY_TOP_N is negative", n));
750+
751+
// If top n is zero or input array is empty then exit early.
752+
if (n == 0 || array.size() == 0) {
753+
return;
754+
}
755+
756+
// Define comparator that wraps built-in function for basic primitives or
757+
// calls floating point handler for NaNs.
758+
using facebook::velox::util::floating_point::NaNAwareGreaterThan;
759+
struct GreaterThanComparator {
760+
bool operator()(
761+
const typename TInput::element_t& a,
762+
const typename TInput::element_t& b) const {
763+
if constexpr (
764+
std::is_same_v<typename TInput::element_t, float> ||
765+
std::is_same_v<typename TInput::element_t, double>) {
766+
return NaNAwareGreaterThan<typename TInput::element_t>{}(a, b);
767+
} else {
768+
return std::greater<typename TInput::element_t>{}(a, b);
769+
}
770+
}
771+
};
772+
773+
// Define min-heap to store the top n elements.
774+
std::priority_queue<
775+
typename TInput::element_t,
776+
std::vector<typename TInput::element_t>,
777+
GreaterThanComparator>
778+
minHeap;
779+
780+
// Iterate through the array and push elements to the min-heap.
781+
GreaterThanComparator comparator;
782+
int numNull = 0;
783+
for (const auto& item : array) {
784+
if (item.has_value()) {
785+
if (minHeap.size() < n) {
786+
minHeap.push(item.value());
787+
} else if (comparator(item.value(), minHeap.top())) {
788+
minHeap.push(item.value());
789+
minHeap.pop();
790+
}
791+
} else {
792+
++numNull;
793+
}
794+
}
795+
796+
// Reverse the min-heap to get the top n elements in descending order.
797+
std::vector<typename TInput::element_t> reversed(minHeap.size());
798+
auto index = minHeap.size();
799+
while (!minHeap.empty()) {
800+
reversed[--index] = minHeap.top();
801+
minHeap.pop();
802+
}
803+
804+
// Copy mutated vector to result vector up to minHeap's size items.
805+
for (const auto& item : reversed) {
806+
result.push_back(item);
807+
}
808+
809+
// Backfill nulls if needed.
810+
while (result.size() < n && numNull > 0) {
811+
result.add_null();
812+
--numNull;
813+
}
814+
}
815+
816+
// Generic implementation.
817+
FOLLY_ALWAYS_INLINE void call(
818+
out_type<Array<Orderable<T1>>>& result,
819+
const arg_type<Array<Orderable<T1>>>& array,
820+
const int32_t n) {
821+
VELOX_CHECK(
822+
n >= 0, fmt::format("Parameter n: {} to ARRAY_TOP_N is negative", n));
823+
824+
// If top n is zero or input array is empty then exit early.
825+
if (n == 0 || array.size() == 0) {
826+
return;
827+
}
828+
829+
// Define comparator to compare complex types.
830+
struct ComplexTypeComparator {
831+
const arg_type<Array<Orderable<T1>>>& array;
832+
ComplexTypeComparator(const arg_type<Array<Orderable<T1>>>& array)
833+
: array(array) {}
834+
835+
bool operator()(const int32_t& a, const int32_t& b) const {
836+
static constexpr CompareFlags kFlags = {
837+
.nullHandlingMode =
838+
CompareFlags::NullHandlingMode::kNullAsIndeterminate};
839+
return array[a].value().compare(array[b].value(), kFlags).value() > 0;
840+
}
841+
};
842+
843+
// Define min-heap to store the top n elements.
844+
std::priority_queue<int32_t, std::vector<int32_t>, ComplexTypeComparator>
845+
minHeap(array);
846+
847+
// Iterate through the array and push elements to the min-heap.
848+
ComplexTypeComparator comparator(array);
849+
int numNull = 0;
850+
for (int i = 0; i < array.size(); ++i) {
851+
if (array[i].has_value()) {
852+
if (minHeap.size() < n) {
853+
minHeap.push(i);
854+
} else if (comparator(i, minHeap.top())) {
855+
minHeap.push(i);
856+
minHeap.pop();
857+
}
858+
} else {
859+
++numNull;
860+
}
861+
}
862+
863+
// Reverse the min-heap to get the top n elements in descending order.
864+
std::vector<int32_t> reversed(minHeap.size());
865+
auto index = minHeap.size();
866+
while (!minHeap.empty()) {
867+
reversed[--index] = minHeap.top();
868+
minHeap.pop();
869+
}
870+
871+
// Copy mutated vector to result vector up to minHeap's size items.
872+
for (const auto& index : reversed) {
873+
result.push_back(array[index].value());
874+
}
875+
876+
// Backfill nulls if needed.
877+
while (result.size() < n && numNull > 0) {
878+
result.add_null();
879+
--numNull;
880+
}
881+
}
882+
};
883+
732884
template <typename T>
733885
struct ArrayTrimFunction {
734886
VELOX_DEFINE_FUNCTION_TYPES(T);
735887

736888
// Fast path for primitives.
737889
template <typename Out, typename In>
738-
void call(Out& out, const In& inputArray, int64_t size) {
890+
void call(Out& out, const In& inputArray, int32_t size) {
739891
checkIndexArrayTrim(size, inputArray.size());
740892

741893
int64_t end = inputArray.size() - size;

velox/functions/prestosql/registration/ArrayFunctionsRegistration.cpp

+19
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,12 @@ inline void registerArrayTrimFunctions(const std::string& prefix) {
9797
{prefix + "trim_array"});
9898
}
9999

100+
template <typename T>
101+
inline void registerArrayTopNFunction(const std::string& prefix) {
102+
registerFunction<ArrayTopNFunction, Array<T>, Array<T>, int32_t>(
103+
{prefix + "array_top_n"});
104+
}
105+
100106
template <typename T>
101107
inline void registerArrayRemoveNullFunctions(const std::string& prefix) {
102108
registerFunction<ArrayRemoveNullFunction, Array<T>, Array<T>>(
@@ -241,6 +247,19 @@ void registerArrayFunctions(const std::string& prefix) {
241247
Array<Varchar>,
242248
int64_t>({prefix + "trim_array"});
243249

250+
registerArrayTopNFunction<int8_t>(prefix);
251+
registerArrayTopNFunction<int16_t>(prefix);
252+
registerArrayTopNFunction<int32_t>(prefix);
253+
registerArrayTopNFunction<int64_t>(prefix);
254+
registerArrayTopNFunction<int128_t>(prefix);
255+
registerArrayTopNFunction<float>(prefix);
256+
registerArrayTopNFunction<double>(prefix);
257+
registerArrayTopNFunction<Varchar>(prefix);
258+
registerArrayTopNFunction<Timestamp>(prefix);
259+
registerArrayTopNFunction<Date>(prefix);
260+
registerArrayTopNFunction<Varbinary>(prefix);
261+
registerArrayTopNFunction<Orderable<T1>>(prefix);
262+
244263
registerArrayRemoveNullFunctions<int8_t>(prefix);
245264
registerArrayRemoveNullFunctions<int16_t>(prefix);
246265
registerArrayRemoveNullFunctions<int32_t>(prefix);

0 commit comments

Comments
 (0)