|
14 | 14 | * limitations under the License.
|
15 | 15 | */
|
16 | 16 |
|
| 17 | +#include "velox/exec/AggregateFunctionRegistry.h" |
17 | 18 | #include <gtest/gtest.h>
|
18 |
| - |
19 | 19 | #include "velox/exec/Aggregate.h"
|
20 |
| -#include "velox/exec/AggregateFunctionRegistry.h" |
21 | 20 | #include "velox/functions/Registerer.h"
|
22 | 21 | #include "velox/type/Type.h"
|
23 | 22 |
|
@@ -72,28 +71,32 @@ class AggregateFunc : public Aggregate {
|
72 | 71 | char** /*groups*/,
|
73 | 72 | int32_t /*numGroups*/,
|
74 | 73 | VectorPtr* /*result*/) override {}
|
| 74 | + static std::vector<std::shared_ptr<AggregateFunctionSignature>> signatures() { |
| 75 | + std::vector<std::shared_ptr<AggregateFunctionSignature>> signatures{ |
| 76 | + AggregateFunctionSignatureBuilder() |
| 77 | + .returnType("bigint") |
| 78 | + .intermediateType("array(bigint)") |
| 79 | + .argumentType("bigint") |
| 80 | + .argumentType("double") |
| 81 | + .build(), |
| 82 | + AggregateFunctionSignatureBuilder() |
| 83 | + .typeVariable("T") |
| 84 | + .returnType("T") |
| 85 | + .intermediateType("array(T)") |
| 86 | + .argumentType("T") |
| 87 | + .argumentType("T") |
| 88 | + .build(), |
| 89 | + AggregateFunctionSignatureBuilder() |
| 90 | + .returnType("date") |
| 91 | + .intermediateType("date") |
| 92 | + .build(), |
| 93 | + }; |
| 94 | + return signatures; |
| 95 | + } |
75 | 96 | };
|
76 | 97 |
|
77 | 98 | bool registerAggregateFunc(const std::string& name) {
|
78 |
| - std::vector<std::shared_ptr<AggregateFunctionSignature>> signatures{ |
79 |
| - AggregateFunctionSignatureBuilder() |
80 |
| - .returnType("bigint") |
81 |
| - .intermediateType("array(bigint)") |
82 |
| - .argumentType("bigint") |
83 |
| - .argumentType("double") |
84 |
| - .build(), |
85 |
| - AggregateFunctionSignatureBuilder() |
86 |
| - .typeVariable("T") |
87 |
| - .returnType("T") |
88 |
| - .intermediateType("array(T)") |
89 |
| - .argumentType("T") |
90 |
| - .argumentType("T") |
91 |
| - .build(), |
92 |
| - AggregateFunctionSignatureBuilder() |
93 |
| - .returnType("date") |
94 |
| - .intermediateType("date") |
95 |
| - .build(), |
96 |
| - }; |
| 99 | + auto signatures = AggregateFunc::signatures(); |
97 | 100 |
|
98 | 101 | registerAggregateFunction(
|
99 | 102 | name,
|
@@ -175,4 +178,24 @@ TEST_F(FunctionRegistryTest, functionNameInMixedCase) {
|
175 | 178 | "aggregatE_funC_aliaS", {DOUBLE(), DOUBLE()}, DOUBLE(), ARRAY(DOUBLE()));
|
176 | 179 | }
|
177 | 180 |
|
| 181 | +TEST_F(FunctionRegistryTest, getAggregateFunctionSignatures) { |
| 182 | + auto functionSignatures = getAggregateFunctionSignatures(); |
| 183 | + auto aggregateFuncSignatures = functionSignatures["aggregate_func"]; |
| 184 | + std::vector<std::string> aggregateFuncSignaturesStr; |
| 185 | + std::transform( |
| 186 | + aggregateFuncSignatures.begin(), |
| 187 | + aggregateFuncSignatures.end(), |
| 188 | + std::back_inserter(aggregateFuncSignaturesStr), |
| 189 | + [](auto& signature) { return signature->toString(); }); |
| 190 | + |
| 191 | + auto expectedSignatures = AggregateFunc::signatures(); |
| 192 | + std::vector<std::string> expectedSignaturesStr; |
| 193 | + std::transform( |
| 194 | + expectedSignatures.begin(), |
| 195 | + expectedSignatures.end(), |
| 196 | + std::back_inserter(expectedSignaturesStr), |
| 197 | + [](auto& signature) { return signature->toString(); }); |
| 198 | + |
| 199 | + ASSERT_EQ(aggregateFuncSignaturesStr, expectedSignaturesStr); |
| 200 | +} |
178 | 201 | } // namespace facebook::velox::exec::test
|
0 commit comments