Skip to content

Commit 59dac08

Browse files
peterenescu authored and facebook-github-bot committed
feat(aggregate): Add complex type to map_union_sum (facebookincubator#12268)
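Summary: Adds support for complex-type map keys (ARRAY, MAP, ROW) to the Presto aggregate function map_union_sum. The addition required some additional surgery to make the primitive and string accumulators forward compatible with the new ComplexType accumulator, namely changing the extractValues/addValues functions to accept generic vector pointers instead of typed vectors. Differential Revision: D69204449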
1 parent 04bfdff commit 59dac08

File tree

2 files changed

+198
-63
lines changed

2 files changed

+198
-63
lines changed

velox/functions/prestosql/aggregates/MapUnionSumAggregate.cpp

+148-63
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@
1313
* See the License for the specific language governing permissions and
1414
* limitations under the License.
1515
*/
16+
#include "velox/exec/AddressableNonNullValueList.h"
1617
#include "velox/exec/Aggregate.h"
1718
#include "velox/exec/Strings.h"
1819
#include "velox/expression/FunctionSignature.h"
19-
#include "velox/functions/lib/CheckedArithmeticImpl.h"
2020
#include "velox/functions/prestosql/aggregates/AggregateNames.h"
2121
#include "velox/vector/FlatVector.h"
2222

@@ -32,7 +32,7 @@ struct Accumulator {
3232
AlignedStlAllocator<std::pair<const K, S>, 16>>::Type;
3333
ValuesMap sums;
3434

35-
explicit Accumulator(HashStringAllocator* allocator)
35+
explicit Accumulator(const TypePtr& /*type*/, HashStringAllocator* allocator)
3636
: sums{AlignedStlAllocator<std::pair<const K, S>, 16>(allocator)} {}
3737

3838
size_t size() const {
@@ -41,18 +41,20 @@ struct Accumulator {
4141

4242
void addValues(
4343
const MapVector* mapVector,
44-
const SimpleVector<K>* mapKeys,
45-
const SimpleVector<S>* mapValues,
44+
const VectorPtr& mapKeys,
45+
const VectorPtr& mapValues,
4646
vector_size_t row,
4747
HashStringAllocator* allocator) {
48+
auto keys = mapKeys->template as<SimpleVector<K>>();
49+
auto values = mapValues->template as<SimpleVector<S>>();
4850
auto offset = mapVector->offsetAt(row);
4951
auto size = mapVector->sizeAt(row);
5052

5153
for (auto i = 0; i < size; ++i) {
5254
// Ignore null map keys.
53-
if (!mapKeys->isNullAt(offset + i)) {
54-
auto key = mapKeys->valueAt(offset + i);
55-
addValue(key, mapValues, offset + i, mapValues->typeKind());
55+
if (!keys->isNullAt(offset + i)) {
56+
auto key = keys->valueAt(offset + i);
57+
addValue(key, values, offset + i, values->typeKind());
5658
}
5759
}
5860
}
@@ -94,13 +96,16 @@ struct Accumulator {
9496
}
9597

9698
vector_size_t extractValues(
97-
FlatVector<K>& mapKeys,
98-
FlatVector<S>& mapValues,
99+
VectorPtr& mapKeys,
100+
VectorPtr& mapValues,
99101
vector_size_t offset) {
102+
auto keys = mapKeys->asFlatVector<K>();
103+
auto values = mapValues->asFlatVector<S>();
104+
100105
auto index = offset;
101106
for (const auto& [key, sum] : sums) {
102-
mapKeys.set(index, key);
103-
mapValues.set(index, sum);
107+
keys->set(index, key);
108+
values->set(index, sum);
104109

105110
++index;
106111
}
@@ -115,26 +120,30 @@ struct StringViewAccumulator {
115120

116121
Strings strings;
117122

118-
explicit StringViewAccumulator(HashStringAllocator* allocator)
119-
: base{allocator} {}
123+
explicit StringViewAccumulator(
124+
const TypePtr& type,
125+
HashStringAllocator* allocator)
126+
: base{type, allocator} {}
120127

121128
size_t size() const {
122129
return base.size();
123130
}
124131

125132
void addValues(
126133
const MapVector* mapVector,
127-
const SimpleVector<StringView>* mapKeys,
128-
const SimpleVector<S>* mapValues,
134+
const VectorPtr& mapKeys,
135+
const VectorPtr& mapValues,
129136
vector_size_t row,
130137
HashStringAllocator* allocator) {
138+
auto keys = mapKeys->template as<SimpleVector<StringView>>();
139+
auto values = mapValues->template as<SimpleVector<S>>();
131140
auto offset = mapVector->offsetAt(row);
132141
auto size = mapVector->sizeAt(row);
133142

134143
for (auto i = 0; i < size; ++i) {
135144
// Ignore null map keys.
136-
if (!mapKeys->isNullAt(offset + i)) {
137-
auto key = mapKeys->valueAt(offset + i);
145+
if (!keys->isNullAt(offset + i)) {
146+
auto key = keys->valueAt(offset + i);
138147

139148
if (!key.isInline()) {
140149
auto it = base.sums.find(key);
@@ -145,19 +154,95 @@ struct StringViewAccumulator {
145154
}
146155
}
147156

148-
base.addValue(key, mapValues, offset + i, mapValues->typeKind());
157+
base.addValue(key, values, offset + i, values->typeKind());
149158
}
150159
}
151160
}
152161

153162
vector_size_t extractValues(
154-
FlatVector<StringView>& mapKeys,
155-
FlatVector<S>& mapValues,
163+
VectorPtr& mapKeys,
164+
VectorPtr& mapValues,
156165
vector_size_t offset) {
157166
return base.extractValues(mapKeys, mapValues, offset);
158167
}
159168
};
160169

170+
/// Maintains a map with keys of type array, map or struct.
171+
template <typename V>
172+
struct ComplexTypeAccumulator {
173+
using ValueMap = folly::F14FastMap<
174+
AddressableNonNullValueList::Entry,
175+
int64_t,
176+
AddressableNonNullValueList::Hash,
177+
AddressableNonNullValueList::EqualTo,
178+
AlignedStlAllocator<
179+
std::pair<const AddressableNonNullValueList::Entry, int64_t>,
180+
16>>;
181+
182+
/// A set of pointers to values stored in AddressableNonNullValueList.
183+
ValueMap sums;
184+
185+
/// Stores unique non-null keys.
186+
AddressableNonNullValueList serializedKeys;
187+
188+
ComplexTypeAccumulator(const TypePtr& type, HashStringAllocator* allocator)
189+
: sums{
190+
0,
191+
AddressableNonNullValueList::Hash{},
192+
AddressableNonNullValueList::EqualTo{type},
193+
AlignedStlAllocator<
194+
std::pair<const AddressableNonNullValueList::Entry, int64_t>,
195+
16>(allocator)} {}
196+
197+
void addValues(
198+
const MapVector* mapVector,
199+
const VectorPtr& mapKeys,
200+
const VectorPtr& mapValues,
201+
vector_size_t row,
202+
HashStringAllocator* allocator) {
203+
auto offset = mapVector->offsetAt(row);
204+
auto size = mapVector->sizeAt(row);
205+
auto values = mapValues->template as<SimpleVector<V>>();
206+
207+
for (auto i = 0; i < size; ++i) {
208+
if (!mapKeys->isNullAt(offset + i)) {
209+
auto entry =
210+
serializedKeys.append(*mapKeys.get(), offset + i, allocator);
211+
212+
auto it = sums.find(entry);
213+
if (it == sums.end()) {
214+
// New entry.
215+
sums[entry] = values->valueAt(offset + i);
216+
} else {
217+
// Existing entry.
218+
sums[entry] += values->valueAt(offset + i);
219+
}
220+
}
221+
}
222+
}
223+
224+
vector_size_t extractValues(
225+
VectorPtr& mapKeys,
226+
VectorPtr& mapValues,
227+
vector_size_t offset) {
228+
auto values = mapValues->asFlatVector<V>();
229+
auto index = offset;
230+
231+
for (const auto& [position, count] : sums) {
232+
AddressableNonNullValueList::read(position, *mapKeys.get(), index);
233+
values->set(index, count);
234+
++index;
235+
}
236+
237+
return sums.size();
238+
}
239+
240+
size_t size() const {
241+
return sums.size();
242+
}
243+
};
244+
245+
// Defines unique accumulators dependent on type.
161246
template <typename K, typename S>
162247
struct AccumulatorTypeTraits {
163248
using AccumulatorType = Accumulator<K, S>;
@@ -168,6 +253,12 @@ struct AccumulatorTypeTraits<StringView, S> {
168253
using AccumulatorType = StringViewAccumulator<S>;
169254
};
170255

256+
template <typename V>
257+
struct AccumulatorTypeTraits<ComplexType, V> {
258+
using AccumulatorType = ComplexTypeAccumulator<V>;
259+
};
260+
261+
// Defines common aggregator.
171262
template <typename K, typename S>
172263
class MapUnionSumAggregate : public exec::Aggregate {
173264
public:
@@ -190,12 +281,18 @@ class MapUnionSumAggregate : public exec::Aggregate {
190281
VELOX_CHECK(mapVector);
191282
mapVector->resize(numGroups);
192283

193-
auto mapKeys = mapVector->mapKeys()->as<FlatVector<K>>();
194-
auto mapValues = mapVector->mapValues()->as<FlatVector<S>>();
284+
auto mapKeysPtr = mapVector->mapKeys();
285+
auto mapValuesPtr = mapVector->mapValues();
195286

196287
auto numElements = countElements(groups, numGroups);
197-
mapKeys->resize(numElements);
198-
mapValues->resize(numElements);
288+
mapVector->mapValues()->as<FlatVector<S>>()->resize(numElements);
289+
290+
// ComplexType cannot be resized the same.
291+
if constexpr (!std::is_same_v<K, ComplexType>) {
292+
mapVector->mapKeys()->as<FlatVector<K>>()->resize(numElements);
293+
} else {
294+
mapVector->mapKeys()->resize(numElements);
295+
}
199296

200297
auto rawNulls = mapVector->mutableRawNulls();
201298
vector_size_t offset = 0;
@@ -208,7 +305,7 @@ class MapUnionSumAggregate : public exec::Aggregate {
208305
clearNull(rawNulls, i);
209306

210307
auto mapSize = value<AccumulatorType>(group)->extractValues(
211-
*mapKeys, *mapValues, offset);
308+
mapKeysPtr, mapValuesPtr, offset);
212309
mapVector->setOffsetAndSize(i, offset, mapSize);
213310
offset += mapSize;
214311
}
@@ -227,8 +324,8 @@ class MapUnionSumAggregate : public exec::Aggregate {
227324
bool /*mayPushdown*/) override {
228325
decodedMaps_.decode(*args[0], rows);
229326
auto mapVector = decodedMaps_.base()->template as<MapVector>();
230-
auto mapKeys = mapVector->mapKeys()->template as<SimpleVector<K>>();
231-
auto mapValues = mapVector->mapValues()->template as<SimpleVector<S>>();
327+
auto mapKeys = mapVector->mapKeys();
328+
auto mapValues = mapVector->mapValues();
232329

233330
rows.applyToSelected([&](auto row) {
234331
if (!decodedMaps_.isNullAt(row)) {
@@ -249,8 +346,8 @@ class MapUnionSumAggregate : public exec::Aggregate {
249346
bool /* mayPushdown */) override {
250347
decodedMaps_.decode(*args[0], rows);
251348
auto mapVector = decodedMaps_.base()->template as<MapVector>();
252-
auto mapKeys = mapVector->mapKeys()->template as<SimpleVector<K>>();
253-
auto mapValues = mapVector->mapValues()->template as<SimpleVector<S>>();
349+
auto mapKeys = mapVector->mapKeys();
350+
auto mapValues = mapVector->mapValues();
254351

255352
auto groupMap = value<AccumulatorType>(group);
256353

@@ -285,7 +382,7 @@ class MapUnionSumAggregate : public exec::Aggregate {
285382
folly::Range<const vector_size_t*> indices) override {
286383
setAllNulls(groups, indices);
287384
for (auto index : indices) {
288-
new (groups[index] + offset_) AccumulatorType{allocator_};
385+
new (groups[index] + offset_) AccumulatorType{resultType_, allocator_};
289386
}
290387
}
291388

@@ -304,8 +401,8 @@ class MapUnionSumAggregate : public exec::Aggregate {
304401
void addMap(
305402
AccumulatorType& groupMap,
306403
const MapVector* mapVector,
307-
const SimpleVector<K>* mapKeys,
308-
const SimpleVector<S>* mapValues,
404+
const VectorPtr& mapKeys,
405+
const VectorPtr& mapValues,
309406
vector_size_t row) const {
310407
auto decodedRow = decodedMaps_.index(row);
311408
groupMap.addValues(mapVector, mapKeys, mapValues, decodedRow, allocator_);
@@ -340,7 +437,8 @@ std::unique_ptr<exec::Aggregate> createMapUnionSumAggregate(
340437
case TypeKind::DOUBLE:
341438
return std::make_unique<MapUnionSumAggregate<K, double>>(resultType);
342439
default:
343-
VELOX_UNREACHABLE();
440+
VELOX_UNREACHABLE(
441+
"Unexpected value type {}", mapTypeKindToName(valueKind));
344442
}
345443
}
346444

@@ -350,35 +448,14 @@ void registerMapUnionSumAggregate(
350448
const std::string& prefix,
351449
bool withCompanionFunctions,
352450
bool overwrite) {
353-
const std::vector<std::string> keyTypes = {
354-
"tinyint",
355-
"smallint",
356-
"integer",
357-
"bigint",
358-
"real",
359-
"double",
360-
"varchar",
361-
"json"};
362-
const std::vector<std::string> valueTypes = {
363-
"tinyint",
364-
"smallint",
365-
"integer",
366-
"bigint",
367-
"double",
368-
"real",
369-
};
370-
371-
std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures;
372-
for (auto keyType : keyTypes) {
373-
for (auto valueType : valueTypes) {
374-
auto mapType = fmt::format("map({},{})", keyType, valueType);
375-
signatures.push_back(exec::AggregateFunctionSignatureBuilder()
376-
.returnType(mapType)
377-
.intermediateType(mapType)
378-
.argumentType(mapType)
379-
.build());
380-
}
381-
}
451+
std::vector<std::shared_ptr<exec::AggregateFunctionSignature>> signatures{
452+
exec::AggregateFunctionSignatureBuilder()
453+
.typeVariable("K")
454+
.typeVariable("V")
455+
.returnType("map(K,V)")
456+
.intermediateType("map(K,V)")
457+
.argumentType("map(K,V)")
458+
.build()};
382459

383460
auto name = prefix + kMapUnionSum;
384461
exec::registerAggregateFunction(
@@ -395,6 +472,8 @@ void registerMapUnionSumAggregate(
395472
auto& mapType = argTypes[0]->asMap();
396473
auto keyTypeKind = mapType.keyType()->kind();
397474
auto valueTypeKind = mapType.valueType()->kind();
475+
const auto keyType = resultType->childAt(0);
476+
398477
switch (keyTypeKind) {
399478
case TypeKind::TINYINT:
400479
return createMapUnionSumAggregate<int8_t>(
@@ -416,8 +495,14 @@ void registerMapUnionSumAggregate(
416495
case TypeKind::VARCHAR:
417496
return createMapUnionSumAggregate<StringView>(
418497
valueTypeKind, resultType);
498+
case TypeKind::ARRAY:
499+
case TypeKind::MAP:
500+
case TypeKind::ROW:
501+
return createMapUnionSumAggregate<ComplexType>(
502+
valueTypeKind, resultType);
419503
default:
420-
VELOX_UNREACHABLE();
504+
VELOX_UNREACHABLE(
505+
"Unexpected key type {}", mapTypeKindToName(keyTypeKind));
421506
}
422507
},
423508
withCompanionFunctions,

0 commit comments

Comments
 (0)