From 3a1a60af69ad9d2f9060ce0c22aa212224fd6053 Mon Sep 17 00:00:00 2001 From: Jialiang Tan Date: Wed, 25 Sep 2024 00:12:17 -0700 Subject: [PATCH] Refactor spilling for RowNumber (#11082) Summary: Simplify the logic and cleanup the unnecessary code in RowNumber spilling Pull Request resolved: https://github.com/facebookincubator/velox/pull/11082 Reviewed By: xiaoxmeng Differential Revision: D63336185 Pulled By: tanjialiang fbshipit-source-id: 993c18d29f42029f43931b2be7e3313598836d86 --- velox/exec/RowNumber.cpp | 68 +++++++++++++--------------------------- velox/exec/RowNumber.h | 6 ++-- 2 files changed, 25 insertions(+), 49 deletions(-) diff --git a/velox/exec/RowNumber.cpp b/velox/exec/RowNumber.cpp index 9f036d613181..4836a1f8b836 100644 --- a/velox/exec/RowNumber.cpp +++ b/velox/exec/RowNumber.cpp @@ -96,42 +96,22 @@ void RowNumber::addInput(RowVectorPtr input) { input_ = std::move(input); } -void RowNumber::addSpillInput() { - VELOX_CHECK_NOT_NULL(input_); - VELOX_CHECK_NULL(inputSpiller_); - ensureInputFits(input_); - if (input_ == nullptr) { - VELOX_CHECK_NOT_NULL(inputSpiller_); - // Memory arbitration might be triggered by ensureInputFits() which will - // spill 'input_'. - return; - } - - const auto numInput = input_->size(); - SelectivityVector rows(numInput); - - VELOX_CHECK(spillConfig_.has_value()); - table_->prepareForGroupProbe( - *lookup_, input_, rows, spillConfig_->startPartitionBit); - table_->groupProbe(*lookup_, spillConfig_->startPartitionBit); - - // Initialize new partitions with zeros. 
- for (auto i : lookup_->newGroups) { - setNumRows(lookup_->hits[i], 0); - } -} - void RowNumber::noMoreInput() { Operator::noMoreInput(); if (inputSpiller_ != nullptr) { - inputSpiller_->finishSpill(spillInputPartitionSet_); - inputSpiller_.reset(); - removeEmptyPartitions(spillInputPartitionSet_); - restoreNextSpillPartition(); + finishSpillInputAndRestoreNext(); } } +void RowNumber::finishSpillInputAndRestoreNext() { + VELOX_CHECK_NOT_NULL(inputSpiller_); + inputSpiller_->finishSpill(spillInputPartitionSet_); + inputSpiller_.reset(); + removeEmptyPartitions(spillInputPartitionSet_); + restoreNextSpillPartition(); +} + void RowNumber::restoreNextSpillPartition() { if (spillInputPartitionSet_.empty()) { return; @@ -181,10 +161,11 @@ void RowNumber::restoreNextSpillPartition() { spillInputPartitionSet_.erase(it); - spillInputReader_->nextBatch(input_); - VELOX_CHECK_NOT_NULL(input_); + RowVectorPtr unspilledInput; + spillInputReader_->nextBatch(unspilledInput); + VELOX_CHECK_NOT_NULL(unspilledInput); // NOTE: spillInputReader_ will at least produce one batch output. 
- addSpillInput(); + addInput(std::move(unspilledInput)); } void RowNumber::ensureInputFits(const RowVectorPtr& input) { @@ -339,19 +320,17 @@ RowVectorPtr RowNumber::getOutput() { output = fillOutput(numInput, nullptr); } + input_ = nullptr; if (spillInputReader_ != nullptr) { - if (spillInputReader_->nextBatch(input_)) { - addSpillInput(); + RowVectorPtr unspilledInput; + if (spillInputReader_->nextBatch(unspilledInput)) { + addInput(std::move(unspilledInput)); } else { - input_ = nullptr; spillInputReader_ = nullptr; table_->clear(); restoreNextSpillPartition(); } - } else { - input_ = nullptr; } - return output; } @@ -522,9 +501,9 @@ void RowNumber::spillInput( } void RowNumber::recursiveSpillInput() { - RowVectorPtr input; - while (spillInputReader_->nextBatch(input)) { - spillInput(input, pool()); + RowVectorPtr unspilledInput; + while (spillInputReader_->nextBatch(unspilledInput)) { + spillInput(unspilledInput, pool()); if (operatorCtx_->driver()->shouldYield()) { yield_ = true; @@ -532,12 +511,7 @@ void RowNumber::recursiveSpillInput() { } } - inputSpiller_->finishSpill(spillInputPartitionSet_); - inputSpiller_.reset(); - spillInputReader_ = nullptr; - - removeEmptyPartitions(spillInputPartitionSet_); - restoreNextSpillPartition(); + finishSpillInputAndRestoreNext(); } void RowNumber::setSpillPartitionBits( diff --git a/velox/exec/RowNumber.h b/velox/exec/RowNumber.h index 4db7d3ca82fd..fa87d5391e8e 100644 --- a/velox/exec/RowNumber.h +++ b/velox/exec/RowNumber.h @@ -64,8 +64,6 @@ class RowNumber : public Operator { void spill(); - void addSpillInput(); - void restoreNextSpillPartition(); SpillPartitionNumSet spillHashTable(); @@ -78,6 +76,10 @@ class RowNumber : public Operator { FlatVector& getOrCreateRowNumberVector(vector_size_t size); + // Finishes the current input spilling and restores the next processing + // partition. 
+ void finishSpillInputAndRestoreNext(); + // Used by recursive spill processing to read the spilled input data from the // previous spill run through 'spillInputReader_' and then spill them back // into a number of sub-partitions. After that, the function restores one of