|
21 | 21 | #include "velox/expression/PrestoCastHooks.h"
|
22 | 22 | #include "velox/functions/Udf.h"
|
23 | 23 | #include "velox/functions/lib/CheckedArithmetic.h"
|
| 24 | +#include "velox/functions/lib/ComparatorUtil.h" |
24 | 25 | #include "velox/functions/prestosql/json/SIMDJsonUtil.h"
|
25 | 26 | #include "velox/functions/prestosql/types/JsonType.h"
|
26 | 27 | #include "velox/type/Conversions.h"
|
27 | 28 | #include "velox/type/FloatingPointUtil.h"
|
28 | 29 |
|
| 30 | +#include <queue> |
| 31 | + |
29 | 32 | namespace facebook::velox::functions {
|
30 | 33 |
|
31 | 34 | template <typename TExecCtx, bool isMax>
|
@@ -729,13 +732,162 @@ inline void checkIndexArrayTrim(int64_t size, int64_t arraySize) {
|
729 | 732 | }
|
730 | 733 | }
|
731 | 734 |
|
| 735 | +/// This class implements the array_top_n function. |
| 736 | +/// |
| 737 | +/// DEFINITION: |
| 738 | +/// array_top_n(array(T), int) -> array(T) |
| 739 | +/// Returns the top n elements of the array in descending order. |
| 740 | +template <typename T> |
| 741 | +struct ArrayTopNFunction { |
| 742 | + VELOX_DEFINE_FUNCTION_TYPES(T); |
| 743 | + |
| 744 | + // Definition for primitives. |
| 745 | + template <typename TReturn, typename TInput> |
| 746 | + FOLLY_ALWAYS_INLINE void |
| 747 | + call(TReturn& result, const TInput& array, int32_t n) { |
| 748 | + VELOX_CHECK( |
| 749 | + n >= 0, fmt::format("Parameter n: {} to ARRAY_TOP_N is negative", n)); |
| 750 | + |
| 751 | + // If top n is zero or input array is empty then exit early. |
| 752 | + if (n == 0 || array.size() == 0) { |
| 753 | + return; |
| 754 | + } |
| 755 | + |
| 756 | + // Define comparator that wraps built-in function for basic primitives or |
| 757 | + // calls floating point handler for NaNs. |
| 758 | + using facebook::velox::util::floating_point::NaNAwareGreaterThan; |
| 759 | + struct GreaterThanComparator { |
| 760 | + bool operator()( |
| 761 | + const typename TInput::element_t& a, |
| 762 | + const typename TInput::element_t& b) const { |
| 763 | + if constexpr ( |
| 764 | + std::is_same_v<typename TInput::element_t, float> || |
| 765 | + std::is_same_v<typename TInput::element_t, double>) { |
| 766 | + return NaNAwareGreaterThan<typename TInput::element_t>{}(a, b); |
| 767 | + } else { |
| 768 | + return std::greater<typename TInput::element_t>{}(a, b); |
| 769 | + } |
| 770 | + } |
| 771 | + }; |
| 772 | + |
| 773 | + // Define min-heap to store the top n elements. |
| 774 | + std::priority_queue< |
| 775 | + typename TInput::element_t, |
| 776 | + std::vector<typename TInput::element_t>, |
| 777 | + GreaterThanComparator> |
| 778 | + minHeap; |
| 779 | + |
| 780 | + // Iterate through the array and push elements to the min-heap. |
| 781 | + GreaterThanComparator comparator; |
| 782 | + int numNull = 0; |
| 783 | + for (const auto& item : array) { |
| 784 | + if (item.has_value()) { |
| 785 | + if (minHeap.size() < n) { |
| 786 | + minHeap.push(item.value()); |
| 787 | + } else if (comparator(item.value(), minHeap.top())) { |
| 788 | + minHeap.push(item.value()); |
| 789 | + minHeap.pop(); |
| 790 | + } |
| 791 | + } else { |
| 792 | + ++numNull; |
| 793 | + } |
| 794 | + } |
| 795 | + |
| 796 | + // Reverse the min-heap to get the top n elements in descending order. |
| 797 | + std::vector<typename TInput::element_t> reversed(minHeap.size()); |
| 798 | + auto index = minHeap.size(); |
| 799 | + while (!minHeap.empty()) { |
| 800 | + reversed[--index] = minHeap.top(); |
| 801 | + minHeap.pop(); |
| 802 | + } |
| 803 | + |
| 804 | + // Copy mutated vector to result vector up to minHeap's size items. |
| 805 | + for (const auto& item : reversed) { |
| 806 | + result.push_back(item); |
| 807 | + } |
| 808 | + |
| 809 | + // Backfill nulls if needed. |
| 810 | + while (result.size() < n && numNull > 0) { |
| 811 | + result.add_null(); |
| 812 | + --numNull; |
| 813 | + } |
| 814 | + } |
| 815 | + |
| 816 | + // Generic implementation. |
| 817 | + FOLLY_ALWAYS_INLINE void call( |
| 818 | + out_type<Array<Orderable<T1>>>& result, |
| 819 | + const arg_type<Array<Orderable<T1>>>& array, |
| 820 | + const int32_t n) { |
| 821 | + VELOX_CHECK( |
| 822 | + n >= 0, fmt::format("Parameter n: {} to ARRAY_TOP_N is negative", n)); |
| 823 | + |
| 824 | + // If top n is zero or input array is empty then exit early. |
| 825 | + if (n == 0 || array.size() == 0) { |
| 826 | + return; |
| 827 | + } |
| 828 | + |
| 829 | + // Define comparator to compare complex types. |
| 830 | + struct ComplexTypeComparator { |
| 831 | + const arg_type<Array<Orderable<T1>>>& array; |
| 832 | + ComplexTypeComparator(const arg_type<Array<Orderable<T1>>>& array) |
| 833 | + : array(array) {} |
| 834 | + |
| 835 | + bool operator()(const int32_t& a, const int32_t& b) const { |
| 836 | + static constexpr CompareFlags kFlags = { |
| 837 | + .nullHandlingMode = |
| 838 | + CompareFlags::NullHandlingMode::kNullAsIndeterminate}; |
| 839 | + return array[a].value().compare(array[b].value(), kFlags).value() > 0; |
| 840 | + } |
| 841 | + }; |
| 842 | + |
| 843 | + // Define min-heap to store the top n elements. |
| 844 | + std::priority_queue<int32_t, std::vector<int32_t>, ComplexTypeComparator> |
| 845 | + minHeap(array); |
| 846 | + |
| 847 | + // Iterate through the array and push elements to the min-heap. |
| 848 | + ComplexTypeComparator comparator(array); |
| 849 | + int numNull = 0; |
| 850 | + for (int i = 0; i < array.size(); ++i) { |
| 851 | + if (array[i].has_value()) { |
| 852 | + if (minHeap.size() < n) { |
| 853 | + minHeap.push(i); |
| 854 | + } else if (comparator(i, minHeap.top())) { |
| 855 | + minHeap.push(i); |
| 856 | + minHeap.pop(); |
| 857 | + } |
| 858 | + } else { |
| 859 | + ++numNull; |
| 860 | + } |
| 861 | + } |
| 862 | + |
| 863 | + // Reverse the min-heap to get the top n elements in descending order. |
| 864 | + std::vector<int32_t> reversed(minHeap.size()); |
| 865 | + auto index = minHeap.size(); |
| 866 | + while (!minHeap.empty()) { |
| 867 | + reversed[--index] = minHeap.top(); |
| 868 | + minHeap.pop(); |
| 869 | + } |
| 870 | + |
| 871 | + // Copy mutated vector to result vector up to minHeap's size items. |
| 872 | + for (const auto& index : reversed) { |
| 873 | + result.push_back(array[index].value()); |
| 874 | + } |
| 875 | + |
| 876 | + // Backfill nulls if needed. |
| 877 | + while (result.size() < n && numNull > 0) { |
| 878 | + result.add_null(); |
| 879 | + --numNull; |
| 880 | + } |
| 881 | + } |
| 882 | +}; |
| 883 | + |
732 | 884 | template <typename T>
|
733 | 885 | struct ArrayTrimFunction {
|
734 | 886 | VELOX_DEFINE_FUNCTION_TYPES(T);
|
735 | 887 |
|
736 | 888 | // Fast path for primitives.
|
737 | 889 | template <typename Out, typename In>
|
738 |
| - void call(Out& out, const In& inputArray, int64_t size) { |
| 890 | + void call(Out& out, const In& inputArray, int32_t size) { |
739 | 891 | checkIndexArrayTrim(size, inputArray.size());
|
740 | 892 |
|
741 | 893 | int64_t end = inputArray.size() - size;
|
|
0 commit comments