From bfeb1891fabd04e1158d77a8a8b85dbc0cb7b7c2 Mon Sep 17 00:00:00 2001 From: Natasha Sehgal Date: Tue, 4 Mar 2025 16:24:32 -0800 Subject: [PATCH] Support TDigestType in Velox Functions (#12326) Summary: Pull Request resolved: https://github.com/facebookincubator/velox/pull/12326 X-link: https://github.com/prestodb/presto/pull/24546 Add support for TDIGEST(DOUBLE) in Prestissimo along with additional test cases Reviewed By: Yuhta Differential Revision: D69558489 fbshipit-source-id: 3d911759af5d80fde4b7653d66fa44a623a52274 --- velox/core/tests/ConstantTypedExprTest.cpp | 4 ++++ velox/expression/tests/CustomTypeTest.cpp | 22 +++++++++++------ .../prestosql/registration/CMakeLists.txt | 1 + .../registration/RegistrationFunctions.cpp | 6 +++++ .../registration/RegistrationFunctions.h | 2 ++ .../TDigestFunctionsRegistration.cpp | 24 +++++++++++++++++++ velox/vector/tests/VectorSaverTest.cpp | 4 ++++ 7 files changed, 56 insertions(+), 7 deletions(-) create mode 100644 velox/functions/prestosql/registration/TDigestFunctionsRegistration.cpp diff --git a/velox/core/tests/ConstantTypedExprTest.cpp b/velox/core/tests/ConstantTypedExprTest.cpp index cba32ed7da05..3d067f9ede87 100644 --- a/velox/core/tests/ConstantTypedExprTest.cpp +++ b/velox/core/tests/ConstantTypedExprTest.cpp @@ -18,6 +18,7 @@ #include "velox/core/Expressions.h" #include "velox/functions/prestosql/types/HyperLogLogType.h" #include "velox/functions/prestosql/types/JsonType.h" +#include "velox/functions/prestosql/types/TDigestType.h" #include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h" namespace facebook::velox::core::test { @@ -47,6 +48,9 @@ TEST(ConstantTypedExprTest, null) { EXPECT_FALSE(*makeNull(HYPERLOGLOG()) == *makeNull(VARBINARY())); EXPECT_FALSE(*makeNull(VARBINARY()) == *makeNull(HYPERLOGLOG())); + EXPECT_FALSE(*makeNull(TDIGEST(DOUBLE())) == *makeNull(VARBINARY())); + EXPECT_FALSE(*makeNull(VARBINARY()) == *makeNull(TDIGEST(DOUBLE()))); + EXPECT_FALSE(*makeNull(TIMESTAMP_WITH_TIME_ZONE()) == *makeNull(BIGINT())); EXPECT_FALSE(*makeNull(BIGINT()) == *makeNull(TIMESTAMP_WITH_TIME_ZONE())); diff --git a/velox/expression/tests/CustomTypeTest.cpp b/velox/expression/tests/CustomTypeTest.cpp index 6e5d3df7cb53..ff3e42883aab 100644 --- a/velox/expression/tests/CustomTypeTest.cpp +++ b/velox/expression/tests/CustomTypeTest.cpp @@ -231,7 +231,7 @@ TEST_F(CustomTypeTest, getCustomTypeNames) { "IPADDRESS", "IPPREFIX", "BINGTILE", - }), + "TDIGEST"}), names); ASSERT_TRUE(registerCustomType( @@ -248,7 +248,7 @@ TEST_F(CustomTypeTest, getCustomTypeNames) { "IPPREFIX", "BINGTILE", "FANCY_INT", - }), + "TDIGEST"}), names); ASSERT_TRUE(unregisterCustomType("fancy_int")); @@ -257,19 +257,27 @@ TEST_F(CustomTypeTest, getCustomTypeNames) { TEST_F(CustomTypeTest, nullConstant) { ASSERT_TRUE(registerCustomType( "fancy_int", std::make_unique())); - - auto names = getCustomTypeNames(); - for (const auto& name : names) { - auto type = getCustomType(name, {}); + auto checkNullConstant = [&](const TypePtr& type, + const std::string& expectedTypeString) { auto null = BaseVector::createNullConstant(type, 10, pool()); EXPECT_TRUE(null->isConstantEncoding()); EXPECT_TRUE(type->equivalent(*null->type())); EXPECT_EQ(type->toString(), null->type()->toString()); + EXPECT_EQ(type->toString(), expectedTypeString); for (auto i = 0; i < 10; ++i) { EXPECT_TRUE(null->isNullAt(i)); } + }; + auto names = getCustomTypeNames(); + for (const auto& name : names) { + if (name == "TDIGEST") { + auto type = getCustomType(name, {TypeParameter(DOUBLE())}); + checkNullConstant(type, "TDIGEST(DOUBLE)"); + } else { + auto type = getCustomType(name, {}); + checkNullConstant(type, type->toString()); + } } - ASSERT_TRUE(unregisterCustomType("fancy_int")); } diff --git a/velox/functions/prestosql/registration/CMakeLists.txt b/velox/functions/prestosql/registration/CMakeLists.txt index 4e0f4a12c475..1cec4506265c 100644 --- a/velox/functions/prestosql/registration/CMakeLists.txt +++ b/velox/functions/prestosql/registration/CMakeLists.txt @@ -33,6 +33,7 @@ velox_add_library( ProbabilityTrigonometricFunctionsRegistration.cpp RegistrationFunctions.cpp StringFunctionsRegistration.cpp + TDigestFunctionsRegistration.cpp URLFunctionsRegistration.cpp) # GCC 12 has a bug where it does not respect "pragma ignore" directives and ends diff --git a/velox/functions/prestosql/registration/RegistrationFunctions.cpp b/velox/functions/prestosql/registration/RegistrationFunctions.cpp index ae24c5bd17c2..3e84b0902e60 100644 --- a/velox/functions/prestosql/registration/RegistrationFunctions.cpp +++ b/velox/functions/prestosql/registration/RegistrationFunctions.cpp @@ -29,6 +29,7 @@ extern void registerComparisonFunctions(const std::string& prefix); extern void registerDateTimeFunctions(const std::string& prefix); extern void registerGeneralFunctions(const std::string& prefix); extern void registerHyperLogFunctions(const std::string& prefix); +extern void registerTDigestFunctions(const std::string& prefix); extern void registerIntegerFunctions(const std::string& prefix); extern void registerJsonFunctions(const std::string& prefix); extern void registerMapFunctions(const std::string& prefix); @@ -72,6 +73,10 @@ void registerHyperLogFunctions(const std::string& prefix) { functions::registerHyperLogFunctions(prefix); } +void registerTDigestFunctions(const std::string& prefix) { + functions::registerTDigestFunctions(prefix); +} + void registerIntegerFunctions(const std::string& prefix) { functions::registerIntegerFunctions(prefix); } @@ -112,6 +117,7 @@ void registerAllScalarFunctions(const std::string& prefix) { registerArrayFunctions(prefix); registerJsonFunctions(prefix); registerHyperLogFunctions(prefix); + registerTDigestFunctions(prefix); registerIntegerFunctions(prefix); registerGeospatialFunctions(prefix); registerGeneralFunctions(prefix); diff --git a/velox/functions/prestosql/registration/RegistrationFunctions.h b/velox/functions/prestosql/registration/RegistrationFunctions.h index 8dccbe168f7f..2f2f6a62b6ab 100644 --- a/velox/functions/prestosql/registration/RegistrationFunctions.h +++ b/velox/functions/prestosql/registration/RegistrationFunctions.h @@ -33,6 +33,8 @@ void registerJsonFunctions(const std::string& prefix = ""); void registerHyperLogFunctions(const std::string& prefix = ""); +void registerTDigestFunctions(const std::string& prefix = ""); + void registerGeospatialFunctions(const std::string& prefix = ""); void registerGeneralFunctions(const std::string& prefix = ""); diff --git a/velox/functions/prestosql/registration/TDigestFunctionsRegistration.cpp b/velox/functions/prestosql/registration/TDigestFunctionsRegistration.cpp new file mode 100644 index 000000000000..1c5b27656278 --- /dev/null +++ b/velox/functions/prestosql/registration/TDigestFunctionsRegistration.cpp @@ -0,0 +1,24 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "velox/functions/Registerer.h" +#include "velox/functions/prestosql/types/TDigestRegistration.h" + +namespace facebook::velox::functions { + +void registerTDigestFunctions(const std::string& prefix) { + registerTDigestType(); +} +} // namespace facebook::velox::functions diff --git a/velox/vector/tests/VectorSaverTest.cpp b/velox/vector/tests/VectorSaverTest.cpp index 8ce086c7c62f..66579cfb3d0e 100644 --- a/velox/vector/tests/VectorSaverTest.cpp +++ b/velox/vector/tests/VectorSaverTest.cpp @@ -23,6 +23,8 @@ #include "velox/functions/prestosql/types/HyperLogLogType.h" #include "velox/functions/prestosql/types/JsonRegistration.h" #include "velox/functions/prestosql/types/JsonType.h" +#include "velox/functions/prestosql/types/TDigestRegistration.h" +#include "velox/functions/prestosql/types/TDigestType.h" #include "velox/functions/prestosql/types/TimestampWithTimeZoneRegistration.h" #include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h" #include "velox/vector/fuzzer/VectorFuzzer.h" @@ -40,6 +42,7 @@ class VectorSaverTest : public testing::Test, public VectorTestBase { registerJsonType(); registerHyperLogLogType(); registerTimestampWithTimeZoneType(); + registerTDigestType(); } void SetUp() override { @@ -268,6 +271,7 @@ TEST_F(VectorSaverTest, types) { testTypeRoundTrip(JSON()); testTypeRoundTrip(HYPERLOGLOG()); testTypeRoundTrip(TIMESTAMP_WITH_TIME_ZONE()); + testTypeRoundTrip(TDIGEST(DOUBLE())); } TEST_F(VectorSaverTest, selectivityVector) {