Skip to content

Commit 0275a64

Browse files
Vinti Pandeyfacebook-github-bot
Vinti Pandey
authored andcommitted
Add an API to fetch all available aggregate functions (facebookincubator#2650)
Summary: Pull Request resolved: facebookincubator#2650 Add getAggregateFunctionSignatures() API to return all registered aggregate functions. For each function return a name and a list of signatures. Reviewed By: mbasmanova Differential Revision: D39713014 fbshipit-source-id: 16aecd6790b0a62f76439db413d23d06db3da5d4
1 parent df538f2 commit 0275a64

File tree

3 files changed

+66
-21
lines changed

3 files changed

+66
-21
lines changed

velox/exec/Aggregate.cpp

+15
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,21 @@ bool registerAggregateFunction(
6161
return true;
6262
}
6363

64+
std::unordered_map<
65+
std::string,
66+
std::vector<std::shared_ptr<AggregateFunctionSignature>>>
67+
getAggregateFunctionSignatures() {
68+
std::unordered_map<
69+
std::string,
70+
std::vector<std::shared_ptr<AggregateFunctionSignature>>>
71+
map;
72+
auto aggregateFunctions = exec::aggregateFunctions();
73+
for (const auto& aggregateFunction : aggregateFunctions) {
74+
map[aggregateFunction.first] = aggregateFunction.second.signatures;
75+
}
76+
return map;
77+
}
78+
6479
std::optional<std::vector<std::shared_ptr<AggregateFunctionSignature>>>
6580
getAggregateFunctionSignatures(const std::string& name) {
6681
if (auto func = getAggregateFunctionEntry(name)) {

velox/exec/Aggregate.h

+7
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,13 @@ bool registerAggregateFunction(
329329
std::optional<std::vector<std::shared_ptr<AggregateFunctionSignature>>>
330330
getAggregateFunctionSignatures(const std::string& name);
331331

332+
/// Returns a mapping of all Aggregate functions in registry.
333+
/// The mapping is function name -> list of function signatures.
334+
std::unordered_map<
335+
std::string,
336+
std::vector<std::shared_ptr<AggregateFunctionSignature>>>
337+
getAggregateFunctionSignatures();
338+
332339
struct AggregateFunctionEntry {
333340
std::vector<std::shared_ptr<AggregateFunctionSignature>> signatures;
334341
AggregateFunctionFactory factory;

velox/exec/tests/AggregateFunctionRegistryTest.cpp

+44-21
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,9 @@
1414
* limitations under the License.
1515
*/
1616

17+
#include "velox/exec/AggregateFunctionRegistry.h"
1718
#include <gtest/gtest.h>
18-
1919
#include "velox/exec/Aggregate.h"
20-
#include "velox/exec/AggregateFunctionRegistry.h"
2120
#include "velox/functions/Registerer.h"
2221
#include "velox/type/Type.h"
2322

@@ -72,28 +71,32 @@ class AggregateFunc : public Aggregate {
7271
char** /*groups*/,
7372
int32_t /*numGroups*/,
7473
VectorPtr* /*result*/) override {}
74+
static std::vector<std::shared_ptr<AggregateFunctionSignature>> signatures() {
75+
std::vector<std::shared_ptr<AggregateFunctionSignature>> signatures{
76+
AggregateFunctionSignatureBuilder()
77+
.returnType("bigint")
78+
.intermediateType("array(bigint)")
79+
.argumentType("bigint")
80+
.argumentType("double")
81+
.build(),
82+
AggregateFunctionSignatureBuilder()
83+
.typeVariable("T")
84+
.returnType("T")
85+
.intermediateType("array(T)")
86+
.argumentType("T")
87+
.argumentType("T")
88+
.build(),
89+
AggregateFunctionSignatureBuilder()
90+
.returnType("date")
91+
.intermediateType("date")
92+
.build(),
93+
};
94+
return signatures;
95+
}
7596
};
7697

7798
bool registerAggregateFunc(const std::string& name) {
78-
std::vector<std::shared_ptr<AggregateFunctionSignature>> signatures{
79-
AggregateFunctionSignatureBuilder()
80-
.returnType("bigint")
81-
.intermediateType("array(bigint)")
82-
.argumentType("bigint")
83-
.argumentType("double")
84-
.build(),
85-
AggregateFunctionSignatureBuilder()
86-
.typeVariable("T")
87-
.returnType("T")
88-
.intermediateType("array(T)")
89-
.argumentType("T")
90-
.argumentType("T")
91-
.build(),
92-
AggregateFunctionSignatureBuilder()
93-
.returnType("date")
94-
.intermediateType("date")
95-
.build(),
96-
};
99+
auto signatures = AggregateFunc::signatures();
97100

98101
registerAggregateFunction(
99102
name,
@@ -175,4 +178,24 @@ TEST_F(FunctionRegistryTest, functionNameInMixedCase) {
175178
"aggregatE_funC_aliaS", {DOUBLE(), DOUBLE()}, DOUBLE(), ARRAY(DOUBLE()));
176179
}
177180

181+
TEST_F(FunctionRegistryTest, getAggregateFunctionSignatures) {
182+
auto functionSignatures = getAggregateFunctionSignatures();
183+
auto aggregateFuncSignatures = functionSignatures["aggregate_func"];
184+
std::vector<std::string> aggregateFuncSignaturesStr;
185+
std::transform(
186+
aggregateFuncSignatures.begin(),
187+
aggregateFuncSignatures.end(),
188+
std::back_inserter(aggregateFuncSignaturesStr),
189+
[](auto& signature) { return signature->toString(); });
190+
191+
auto expectedSignatures = AggregateFunc::signatures();
192+
std::vector<std::string> expectedSignaturesStr;
193+
std::transform(
194+
expectedSignatures.begin(),
195+
expectedSignatures.end(),
196+
std::back_inserter(expectedSignaturesStr),
197+
[](auto& signature) { return signature->toString(); });
198+
199+
ASSERT_EQ(aggregateFuncSignaturesStr, expectedSignaturesStr);
200+
}
178201
} // namespace facebook::velox::exec::test

0 commit comments

Comments
 (0)