diff --git a/velox/expression/SimpleFunctionRegistry.cpp b/velox/expression/SimpleFunctionRegistry.cpp index 34d746793286..9b08d6e18df5 100644 --- a/velox/expression/SimpleFunctionRegistry.cpp +++ b/velox/expression/SimpleFunctionRegistry.cpp @@ -68,6 +68,11 @@ bool SimpleFunctionRegistry::registerFunctionInternal( }); } +void SimpleFunctionRegistry::removeFunction(const std::string& name) { + const auto sanitizedName = sanitizeName(name); + registeredFunctions_.withWLock([&](auto& map) { map.erase(sanitizedName); }); +} + namespace { // This function is not thread safe. It should be called only from within a // synchronized read region of registeredFunctions_. diff --git a/velox/expression/SimpleFunctionRegistry.h b/velox/expression/SimpleFunctionRegistry.h index 61c77a23c3f0..ffd5ec5deb0d 100644 --- a/velox/expression/SimpleFunctionRegistry.h +++ b/velox/expression/SimpleFunctionRegistry.h @@ -91,6 +91,8 @@ class SimpleFunctionRegistry { } } + void removeFunction(const std::string& name); + std::vector getFunctionNames() const { std::vector result; registeredFunctions_.withRLock([&](const auto& map) { diff --git a/velox/functions/FunctionRegistry.cpp b/velox/functions/FunctionRegistry.cpp index 406d48280733..2ce84364a34a 100644 --- a/velox/functions/FunctionRegistry.cpp +++ b/velox/functions/FunctionRegistry.cpp @@ -181,4 +181,10 @@ resolveVectorFunctionWithMetadata( return exec::resolveVectorFunctionWithMetadata(functionName, argTypes); } +void removeFunction(const std::string& functionName) { + exec::mutableSimpleFunctions().removeFunction(functionName); + exec::vectorFunctionFactories().withWLock( + [&](auto& functionMap) { functionMap.erase(functionName); }); +} + } // namespace facebook::velox diff --git a/velox/functions/FunctionRegistry.h b/velox/functions/FunctionRegistry.h index b07cfb942130..706fe5d4f4c4 100644 --- a/velox/functions/FunctionRegistry.h +++ b/velox/functions/FunctionRegistry.h @@ -97,6 +97,10 @@ resolveVectorFunctionWithMetadata( const std::string& functionName, const std::vector& argTypes); +/// Given name of a function, removes it from both the simple and vector +/// function registries (including all signatures). +void removeFunction(const std::string& functionName); + /// Clears the function registry. void clearFunctionRegistry(); diff --git a/velox/functions/tests/FunctionRegistryTest.cpp b/velox/functions/tests/FunctionRegistryTest.cpp index fd34a8fd4d50..e94d849bbb1a 100644 --- a/velox/functions/tests/FunctionRegistryTest.cpp +++ b/velox/functions/tests/FunctionRegistryTest.cpp @@ -81,6 +81,10 @@ inline void registerTestFunctions() { VELOX_REGISTER_VECTOR_FUNCTION(udf_vector_func_three, "vector_func_three"); VELOX_REGISTER_VECTOR_FUNCTION(udf_vector_func_four, "vector_func_four"); } + +inline void registerTestVectorFunctionOne(const std::string& functionName) { + VELOX_REGISTER_VECTOR_FUNCTION(udf_vector_func_one, functionName); +} } // namespace class FunctionRegistryTest : public testing::Test { @@ -107,6 +111,44 @@ class FunctionRegistryTest : public testing::Test { } }; +TEST_F(FunctionRegistryTest, removeFunction) { + const std::string functionName = "func_to_remove"; + auto checkFunctionExists = [&](const std::string& name, + bool vectorFuncSignatures, + bool simpleFuncSignatures) { + EXPECT_EQ( + getFunctionSignatures(name).size(), + vectorFuncSignatures + simpleFuncSignatures); + EXPECT_EQ(getVectorFunctionSignatures().count(name), vectorFuncSignatures); + EXPECT_EQ( + exec::simpleFunctions().getFunctionSignatures(name).size(), + simpleFuncSignatures); + }; + + checkFunctionExists(functionName, 0, 0); + + // Only vector function registered + registerTestVectorFunctionOne(functionName); + checkFunctionExists(functionName, 1, 0); + removeFunction(functionName); + checkFunctionExists(functionName, 0, 0); + + // Only simple function registered + registerFunction( + std::vector{functionName}); + checkFunctionExists(functionName, 0, 1); + removeFunction(functionName); + checkFunctionExists(functionName, 0, 0); + + // Both vector and simple function registered + registerTestVectorFunctionOne(functionName); + registerFunction( + std::vector{functionName}); + checkFunctionExists(functionName, 1, 1); + removeFunction(functionName); + checkFunctionExists(functionName, 0, 0); +} + TEST_F(FunctionRegistryTest, getFunctionSignaturesByName) { { auto signatures = getFunctionSignatures("func_one");