Skip to content

Commit 58ff291

Browse files
natashasehgalfacebook-github-bot
authored andcommitted
TDigest Value at Quantile Functions (facebookincubator#12529)
Summary: Pull Request resolved: facebookincubator#12529 Add value_at_quantile and values_at_quantiles functions. Reviewed By: Yuhta Differential Revision: D70256593 fbshipit-source-id: 6860fe1f5d86caef5a9163dc09dd2c834279e9f5
1 parent ee23d0e commit 58ff291

File tree

7 files changed

+171
-9
lines changed

7 files changed

+171
-9
lines changed

velox/exec/fuzzer/PrestoQueryRunner.cpp

+7-6
Original file line numberDiff line numberDiff line change
@@ -451,16 +451,17 @@ bool PrestoQueryRunner::isConstantExprSupported(
451451
bool PrestoQueryRunner::isSupported(const exec::FunctionSignature& signature) {
452452
// TODO: support queries with these types. Among the types below, hugeint is
453453
// not a native type in Presto, so fuzzer should not use it as the type of
454-
// cast-to or constant literals. Hyperloglog can only be casted from varbinary
455-
// and cannot be used as the type of constant literals. Interval year to month
456-
// can only be casted from NULL and cannot be used as the type of constant
457-
// literals. Json, Ipaddress, Ipprefix, and UUID require special handling,
458-
// because Presto requires literals of these types to be valid, and doesn't
459-
// allow creating HIVE columns of these types.
454+
// cast-to or constant literals. Hyperloglog and TDigest can only be cast
455+
// from varbinary and cannot be used as the type of constant literals.
456+
// Interval year to month can only be cast from NULL and cannot be used as
457+
// the type of constant literals. Json, Ipaddress, Ipprefix, and UUID require
458+
// special handling, because Presto requires literals of these types to be
459+
// valid, and doesn't allow creating HIVE columns of these types.
460460
return !(
461461
usesTypeName(signature, "interval year to month") ||
462462
usesTypeName(signature, "hugeint") ||
463463
usesTypeName(signature, "hyperloglog") ||
464+
usesTypeName(signature, "tdigest") ||
464465
usesInputTypeName(signature, "ipaddress") ||
465466
usesInputTypeName(signature, "ipprefix") ||
466467
usesInputTypeName(signature, "uuid"));

velox/exec/tests/FunctionResolutionTest.cpp

+16
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "velox/functions/prestosql/types/BingTileType.h"
2525
#include "velox/functions/prestosql/types/HyperLogLogType.h"
2626
#include "velox/functions/prestosql/types/JsonType.h"
27+
#include "velox/functions/prestosql/types/TDigestType.h"
2728
#include "velox/functions/prestosql/types/TimestampWithTimeZoneType.h"
2829

2930
namespace {
@@ -289,6 +290,21 @@ TEST_F(FunctionResolutionTest, resolveCustomTypeHyperLogLog) {
289290
EXPECT_EQ(type->toString(), HYPERLOGLOG()->toString());
290291
}
291292

293+
template <typename T>
294+
struct FuncTDigest {
295+
VELOX_DEFINE_FUNCTION_TYPES(T);
296+
bool call(out_type<SimpleTDigest<double>>&) {
297+
return false;
298+
}
299+
};
300+
301+
// Registers a zero-argument function returning SimpleTDigest<double> and
// checks that the simple-function registry resolves its return type to
// TDIGEST(DOUBLE).
TEST_F(FunctionResolutionTest, resolveCustomTypeTDigest) {
  registerFunction<FuncTDigest, SimpleTDigest<double>>({"f_tdigest"});

  auto resolvedType =
      exec::simpleFunctions().resolveFunction("f_tdigest", {})->type();
  const auto expectedType = TDIGEST(DOUBLE());
  // Compare printed forms, as the neighbouring custom-type tests do.
  EXPECT_EQ(resolvedType->toString(), expectedType->toString());
}
307+
292308
template <typename T>
293309
struct FuncJson {
294310
VELOX_DEFINE_FUNCTION_TYPES(T);

velox/expression/fuzzer/ExpressionFuzzerTest.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,12 @@ int main(int argc, char** argv) {
8585
"cardinality",
8686
"element_at",
8787
"width_bucket",
88+
// Fuzzer and the underlying engine are confused about TDigest functions
89+
// (since TDigest is a user-defined type), and try to pass a
90+
// VARBINARY (since TDigest's implementation uses an
91+
// alias to VARBINARY).
92+
"value_at_quantile",
93+
"values_at_quantiles",
8894
// Fuzzer cannot generate valid 'comparator' lambda.
8995
"array_sort(array(T),constant function(T,T,bigint)) -> array(T)",
9096
"split_to_map(varchar,varchar,varchar,function(varchar,varchar,varchar,varchar)) -> map(varchar,varchar)",

velox/functions/lib/TDigest.h

+5-1
Original file line numberDiff line numberDiff line change
@@ -447,11 +447,15 @@ void TDigest<A>::mergeDeserialized(
447447
tdigest::detail::read(input, sum);
448448
}
449449
tdigest::detail::read(input, compression);
450-
VELOX_CHECK_EQ(compression, compression_);
450+
// If the TDigest is empty, set compression from TDigest being merged.
451+
if (weights_.empty()) {
452+
setCompression(compression);
453+
}
451454
tdigest::detail::read(input, totalWeight);
452455
int32_t numNew;
453456
tdigest::detail::read(input, numNew);
454457
if (numNew > 0) {
458+
VELOX_CHECK_EQ(compression, compression_);
455459
auto numOld = weights_.size();
456460
weights_.resize(numOld + numNew);
457461
auto* weights = weights_.data() + numOld;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
/*
2+
* Copyright (c) Facebook, Inc. and its affiliates.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
#pragma once
18+
19+
#include "velox/functions/Macros.h"
20+
#include "velox/functions/lib/TDigest.h"
21+
#include "velox/functions/prestosql/types/TDigestType.h"
22+
23+
namespace facebook::velox::functions {
24+
25+
template <typename T>
26+
struct ValueAtQuantileFunction {
27+
VELOX_DEFINE_FUNCTION_TYPES(T);
28+
FOLLY_ALWAYS_INLINE void call(
29+
out_type<double>& result,
30+
const arg_type<SimpleTDigest<double>>& input,
31+
const arg_type<double>& quantile) {
32+
TDigest<> digest;
33+
std::vector<int16_t> positions;
34+
digest.mergeDeserialized(positions, input.data());
35+
digest.compress(positions);
36+
result = digest.estimateQuantile(quantile);
37+
}
38+
};
39+
40+
template <typename T>
41+
struct ValuesAtQuantilesFunction {
42+
VELOX_DEFINE_FUNCTION_TYPES(T);
43+
44+
FOLLY_ALWAYS_INLINE void call(
45+
out_type<Array<double>>& result,
46+
const arg_type<SimpleTDigest<double>>& input,
47+
const arg_type<Array<double>>& quantiles) {
48+
TDigest<> digest;
49+
std::vector<int16_t> positions;
50+
digest.mergeDeserialized(positions, input.data());
51+
digest.compress(positions);
52+
result.resize(quantiles.size());
53+
for (size_t i = 0; i < quantiles.size(); ++i) {
54+
result[i] = digest.estimateQuantile(quantiles[i].value());
55+
}
56+
}
57+
};
58+
59+
} // namespace facebook::velox::functions

velox/functions/prestosql/registration/TDigestFunctionsRegistration.cpp

+13-2
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,22 @@
1414
* limitations under the License.
1515
*/
1616
#include "velox/functions/Registerer.h"
17+
#include "velox/functions/prestosql/TDigestFunctions.h"
1718
#include "velox/functions/prestosql/types/TDigestRegistration.h"
18-
19+
#include "velox/functions/prestosql/types/TDigestType.h"
1920
namespace facebook::velox::functions {
2021

2122
// Registers the TDigest custom type and then the two TDigest scalar
// functions. Each function is registered under 'prefix' followed by its base
// name: "value_at_quantile" (tdigest(double), double -> double) and
// "values_at_quantiles" (tdigest(double), array(double) -> array(double)).
void registerTDigestFunctions(const std::string& prefix) {
22-
registerTDigestType();
23+
facebook::velox::registerTDigestType();
24+
registerFunction<
25+
ValueAtQuantileFunction,
26+
double,
27+
SimpleTDigest<double>,
28+
double>({prefix + "value_at_quantile"});
29+
registerFunction<
30+
ValuesAtQuantilesFunction,
31+
Array<double>,
32+
SimpleTDigest<double>,
33+
Array<double>>({prefix + "values_at_quantiles"});
2334
}
2435
} // namespace facebook::velox::functions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
/*
2+
* Copyright (c) Facebook, Inc. and its affiliates.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
#include <folly/base64.h>
17+
#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h"
18+
#include "velox/functions/prestosql/types/TDigestRegistration.h"
19+
#include "velox/functions/prestosql/types/TDigestType.h"
20+
21+
using namespace facebook::velox;
22+
using namespace facebook::velox::exec;
23+
using namespace facebook::velox::functions::test;
24+
25+
class TDigestFunctionsTest : public FunctionBaseTest {
26+
protected:
27+
void SetUp() override {
28+
FunctionBaseTest::SetUp();
29+
registerTDigestType();
30+
}
31+
32+
protected:
33+
std::string decodeBase64(std::string_view input) {
34+
std::string decoded(folly::base64DecodedSize(input), '\0');
35+
folly::base64Decode(input, decoded.data());
36+
return decoded;
37+
}
38+
};
39+
40+
// Evaluates value_at_quantile against one fixed serialized digest at several
// quantiles.
TEST_F(TDigestFunctionsTest, valueAtQuantile) {
  const TypePtr digestType = TDIGEST(DOUBLE());

  // Base64 of a serialized TDigest. NOTE(review): the payload appears to hold
  // five unit-weight centroids with means 1 through 5 -- confirm against the
  // serialization format before relying on this.
  const std::string serialized = decodeBase64(
      "AQAAAAAAAADwPwAAAAAAABRAAAAAAAAALkAAAAAAAABZQAAAAAAAABRABQAAAAAAAAAAAPA/AAAAAAAA8D8AAAAAAADwPwAAAAAAAPA/AAAAAAAA8D8AAAAAAADwPwAAAAAAAABAAAAAAAAACEAAAAAAAAAQQAAAAAAAABRA");

  const auto estimateAt = [&](const std::optional<std::string>& digest,
                              const std::optional<double>& quantile) {
    return evaluateOnce<double>(
        "value_at_quantile(c0, c1)", digestType, digest, quantile);
  };

  ASSERT_EQ(1.0, estimateAt(serialized, 0.1));
  ASSERT_EQ(3.0, estimateAt(serialized, 0.5));
  ASSERT_EQ(5.0, estimateAt(serialized, 0.9));
  ASSERT_EQ(5.0, estimateAt(serialized, 0.99));
}
54+
55+
// Evaluates values_at_quantiles on a single row: one serialized digest and
// one array of four quantiles, expecting one array of four estimates.
TEST_F(TDigestFunctionsTest, valuesAtQuantiles) {
  const TypePtr digestType = TDIGEST(DOUBLE());

  // Same base64 fixture as used by the valueAtQuantile test.
  const std::string serialized = decodeBase64(
      "AQAAAAAAAADwPwAAAAAAABRAAAAAAAAALkAAAAAAAABZQAAAAAAAABRABQAAAAAAAAAAAPA/AAAAAAAA8D8AAAAAAADwPwAAAAAAAPA/AAAAAAAA8D8AAAAAAADwPwAAAAAAAABAAAAAAAAACEAAAAAAAAAQQAAAAAAAABRA");

  VectorPtr digests = makeFlatVector<std::string>({serialized}, digestType);
  VectorPtr quantiles =
      makeNullableArrayVector<double>({{0.1, 0.5, 0.9, 0.99}});
  auto expected = makeNullableArrayVector<double>({{1.0, 3.0, 5.0, 5.0}});

  auto actual = evaluate(
      "values_at_quantiles(c0, c1)", makeRowVector({digests, quantiles}));
  test::assertEqualVectors(expected, actual);
}

0 commit comments

Comments
 (0)