diff --git a/velox/dwio/dwrf/reader/ReaderBase.cpp b/velox/dwio/dwrf/reader/ReaderBase.cpp index 71dbc2e41510..16bd779b500a 100644 --- a/velox/dwio/dwrf/reader/ReaderBase.cpp +++ b/velox/dwio/dwrf/reader/ReaderBase.cpp @@ -343,13 +343,19 @@ std::shared_ptr ReaderBase::convertType( return SMALLINT(); case TypeKind::INTEGER: return INTEGER(); - case TypeKind::BIGINT: + case TypeKind::BIGINT: { + TypePtr converted; if (type.format() == DwrfFormat::kOrc && type.getOrcPtr()->kind() == proto::orc::Type_Kind_DECIMAL) { - return DECIMAL( - type.getOrcPtr()->precision(), type.getOrcPtr()->scale()); + converted = + DECIMAL(type.getOrcPtr()->precision(), type.getOrcPtr()->scale()); + } else { + converted = BIGINT(); + common::testutil::TestValue::adjust( + "facebook::velox::dwrf::ReaderBase::convertType", &converted); } - return BIGINT(); + return converted; + } case TypeKind::HUGEINT: if (type.format() == DwrfFormat::kOrc && type.getOrcPtr()->kind() == proto::orc::Type_Kind_DECIMAL) { diff --git a/velox/dwio/dwrf/reader/SelectiveDecimalColumnReader.cpp b/velox/dwio/dwrf/reader/SelectiveDecimalColumnReader.cpp index 4f6c19d8c445..644bc07de1c9 100644 --- a/velox/dwio/dwrf/reader/SelectiveDecimalColumnReader.cpp +++ b/velox/dwio/dwrf/reader/SelectiveDecimalColumnReader.cpp @@ -67,16 +67,17 @@ void SelectiveDecimalColumnReader::seekToRowGroup(int64_t index) { template template -void SelectiveDecimalColumnReader::readHelper(RowSet rows) { - vector_size_t numRows = rows.back() + 1; +void SelectiveDecimalColumnReader::readHelper( + common::Filter* filter, + RowSet rows) { ExtractToReader extractValues(this); - common::AlwaysTrue filter; + common::AlwaysTrue alwaysTrue; DirectRleColumnVisitor< int64_t, common::AlwaysTrue, decltype(extractValues), kDense> - visitor(filter, this, rows, extractValues); + visitor(alwaysTrue, this, rows, extractValues); // decode scale stream if (version_ == velox::dwrf::RleVersion_1) { @@ -96,14 +97,161 @@ void SelectiveDecimalColumnReader::readHelper(RowSet rows) { // reset numValues_ before reading values numValues_ = 0; valueSize_ = sizeof(DataT); + vector_size_t numRows = rows.back() + 1; ensureValuesCapacity(numRows); // decode value stream facebook::velox::dwio::common:: ColumnVisitor - valueVisitor(filter, this, rows, extractValues); + valueVisitor(alwaysTrue, this, rows, extractValues); decodeWithVisitor>(valueDecoder_.get(), valueVisitor); readOffset_ += numRows; + + // Fill decimals before applying filter. + fillDecimals(); + + const auto rawNulls = nullsInReadRange_ + ? (kDense ? nullsInReadRange_->as() : rawResultNulls_) + : nullptr; + // Process filter. + process(filter, rows, rawNulls); +} + +template +void SelectiveDecimalColumnReader::processNulls( + bool isNull, + const RowSet& rows, + const uint64_t* rawNulls) { + if (!rawNulls) { + return; + } + returnReaderNulls_ = false; + anyNulls_ = !isNull; + allNull_ = isNull; + + auto rawDecimal = values_->asMutable(); + auto rawScale = scaleBuffer_->asMutable(); + + vector_size_t idx = 0; + if (isNull) { + for (vector_size_t i = 0; i < numValues_; i++) { + if (bits::isBitNull(rawNulls, i)) { + bits::setNull(rawResultNulls_, idx); + addOutputRow(rows[i]); + idx++; + } + } + } else { + for (vector_size_t i = 0; i < numValues_; i++) { + if (!bits::isBitNull(rawNulls, i)) { + bits::setNull(rawResultNulls_, idx, false); + rawDecimal[idx] = rawDecimal[i]; + rawScale[idx] = rawScale[i]; + addOutputRow(rows[i]); + idx++; + } + } + } +} + +template +void SelectiveDecimalColumnReader::processFilter( + const common::Filter* filter, + const RowSet& rows, + const uint64_t* rawNulls) { + VELOX_CHECK_NOT_NULL(filter, "Filter must not be null."); + returnReaderNulls_ = false; + anyNulls_ = false; + allNull_ = true; + + vector_size_t idx = 0; + auto rawDecimal = values_->asMutable(); + for (vector_size_t i = 0; i < numValues_; i++) { + if (rawNulls && bits::isBitNull(rawNulls, i)) { + if (filter->testNull()) { + bits::setNull(rawResultNulls_, idx); + addOutputRow(rows[i]); + anyNulls_ = true; + idx++; + } + } else { + bool tested; + if constexpr (std::is_same_v) { + tested = filter->testInt64(rawDecimal[i]); + } else { + tested = filter->testInt128(rawDecimal[i]); + } + + if (tested) { + if (rawNulls) { + bits::setNull(rawResultNulls_, idx, false); + } + rawDecimal[idx] = rawDecimal[i]; + addOutputRow(rows[i]); + allNull_ = false; + idx++; + } + } + } +} + +template +void SelectiveDecimalColumnReader::process( + const common::Filter* filter, + const RowSet& rows, + const uint64_t* rawNulls) { + // Treat the filter as kAlwaysTrue if any of the following conditions are met: + // 1) No filter found; + // 2) Filter is kIsNotNull but rawNulls == NULL (no elements is null). + auto filterKind = + !filter || (filter->kind() == common::FilterKind::kIsNotNull && !rawNulls) + ? common::FilterKind::kAlwaysTrue + : filter->kind(); + switch (filterKind) { + case common::FilterKind::kAlwaysTrue: + // Simply add all rows to output. + for (vector_size_t i = 0; i < numValues_; i++) { + addOutputRow(rows[i]); + } + break; + case common::FilterKind::kIsNull: + processNulls(true, rows, rawNulls); + break; + case common::FilterKind::kIsNotNull: + processNulls(false, rows, rawNulls); + break; + case common::FilterKind::kBigintRange: + case common::FilterKind::kBigintValuesUsingHashTable: + case common::FilterKind::kBigintValuesUsingBitmask: + case common::FilterKind::kNegatedBigintRange: + case common::FilterKind::kNegatedBigintValuesUsingHashTable: + case common::FilterKind::kNegatedBigintValuesUsingBitmask: + case common::FilterKind::kBigintMultiRange: { + if constexpr (std::is_same_v) { + processFilter(filter, rows, rawNulls); + } else { + const auto actualType = CppToType::create(); + VELOX_NYI( + "Expected type BIGINT, but found file type {}.", + actualType->toString()); + } + break; + } + case common::FilterKind::kHugeintValuesUsingHashTable: + case common::FilterKind::kHugeintRange: { + if constexpr (std::is_same_v) { + processFilter(filter, rows, rawNulls); + } else { + const auto actualType = CppToType::create(); + VELOX_NYI( + "Expected type HUGEINT, but found file type {}.", + actualType->toString()); + } + break; + } + default: + VELOX_NYI("Unsupported filter: {}.", static_cast(filterKind)); + } } template @@ -111,14 +259,20 @@ void SelectiveDecimalColumnReader::read( int64_t offset, const RowSet& rows, const uint64_t* incomingNulls) { - VELOX_CHECK(!scanSpec_->filter()); VELOX_CHECK(!scanSpec_->valueHook()); prepareRead(offset, rows, incomingNulls); + if (!resultNulls_ || !resultNulls_->unique() || + resultNulls_->capacity() * 8 < rows.size()) { + // Make sure a dedicated resultNulls_ is allocated with enough capacity as + // RleDecoder always assumes it is available. + resultNulls_ = AlignedBuffer::allocate(rows.size(), memoryPool_); + rawResultNulls_ = resultNulls_->asMutable(); + } bool isDense = rows.back() == rows.size() - 1; if (isDense) { - readHelper(rows); + readHelper(scanSpec_->filter(), rows); } else { - readHelper(rows); + readHelper(scanSpec_->filter(), rows); } } @@ -126,16 +280,18 @@ template void SelectiveDecimalColumnReader::getValues( const RowSet& rows, VectorPtr* result) { + rawValues_ = values_->asMutable(); + getIntValues(rows, requestedType_, result); +} + +template +void SelectiveDecimalColumnReader::fillDecimals() { auto nullsPtr = resultNulls() ? resultNulls()->template as() : nullptr; auto scales = scaleBuffer_->as(); auto values = values_->asMutable(); - DecimalUtil::fillDecimals( values, nullsPtr, values, scales, numValues_, scale_); - - rawValues_ = values_->asMutable(); - getIntValues(rows, requestedType_, result); } template class SelectiveDecimalColumnReader; diff --git a/velox/dwio/dwrf/reader/SelectiveDecimalColumnReader.h b/velox/dwio/dwrf/reader/SelectiveDecimalColumnReader.h index 67a82b051e36..4482ef47fc50 100644 --- a/velox/dwio/dwrf/reader/SelectiveDecimalColumnReader.h +++ b/velox/dwio/dwrf/reader/SelectiveDecimalColumnReader.h @@ -49,7 +49,24 @@ class SelectiveDecimalColumnReader : public SelectiveColumnReader { private: template - void readHelper(RowSet rows); + void readHelper(common::Filter* filter, RowSet rows); + + // Process IsNull and IsNotNull filters. + void processNulls(bool isNull, const RowSet& rows, const uint64_t* rawNulls); + + // Process filters on decimal values. + void processFilter( + const common::Filter* filter, + const RowSet& rows, + const uint64_t* rawNulls); + + // Dispatch to the respective filter processing based on the filter type. + void process( + const common::Filter* filter, + const RowSet& rows, + const uint64_t* rawNulls); + + void fillDecimals(); std::unique_ptr> valueDecoder_; std::unique_ptr> scaleDecoder_; diff --git a/velox/dwio/dwrf/test/E2EFilterTest.cpp b/velox/dwio/dwrf/test/E2EFilterTest.cpp index 43b67e91e550..6a56da80e891 100644 --- a/velox/dwio/dwrf/test/E2EFilterTest.cpp +++ b/velox/dwio/dwrf/test/E2EFilterTest.cpp @@ -15,6 +15,7 @@ */ #include "velox/common/base/Portability.h" +#include "velox/common/base/tests/GTestUtils.h" #include "velox/common/testutil/TestValue.h" #include "velox/dwio/common/tests/utils/E2EFilterTestBase.h" #include "velox/dwio/dwrf/reader/DwrfReader.h" @@ -64,11 +65,11 @@ class E2EFilterTest : public E2EFilterTestBase { const TypePtr& type, const std::vector& batches, bool forRowGroupSkip = false) override { - auto options = createWriterOptions(type); + setWriterOptions(type); int32_t flushCounter = 0; // If we test row group skip, we have all the data in one stripe. For // scan, we start a stripe every 'flushEveryNBatches_' batches. - options.flushPolicyFactory = [&]() { + options_.flushPolicyFactory = [&]() { return std::make_unique([&]() { return forRowGroupSkip ? false : (++flushCounter % flushEveryNBatches_ == 0); @@ -80,8 +81,8 @@ class E2EFilterTest : public E2EFilterTestBase { dwio::common::FileSink::Options{.pool = leafPool_.get()}); ASSERT_TRUE(sink->isBuffered()); auto* sinkPtr = sink.get(); - options.memoryPool = rootPool_.get(); - writer_ = std::make_unique(std::move(sink), options); + options_.memoryPool = rootPool_.get(); + writer_ = std::make_unique(std::move(sink), options_); for (auto& batch : batches) { writer_->write(batch); } @@ -105,9 +106,10 @@ class E2EFilterTest : public E2EFilterTestBase { } std::unordered_set flatMapColumns_; + dwrf::WriterOptions options_; private: - dwrf::WriterOptions createWriterOptions(const TypePtr& type) { + void setWriterOptions(const TypePtr& type) { auto config = std::make_shared(); config->set(dwrf::Config::COMPRESSION, CompressionKind_NONE); config->set(dwrf::Config::USE_VINTS, useVInts_); @@ -148,10 +150,8 @@ class E2EFilterTest : public E2EFilterTestBase { config->set>>( dwrf::Config::MAP_FLAT_COLS_STRUCT_KEYS, mapFlatColsStructKeys); } - dwrf::WriterOptions options; - options.config = config; - options.schema = writerSchema; - return options; + options_.config = config; + options_.schema = writerSchema; } std::unique_ptr writer_; @@ -227,6 +227,74 @@ TEST_F(E2EFilterTest, byteRle) { 20); } +DEBUG_ONLY_TEST_F(E2EFilterTest, shortDecimal) { + testutil::TestValue::enable(); + options_.format = DwrfFormat::kOrc; + const std::unordered_map types = { + {"shortdecimal_val:decimal(8, 5)", DECIMAL(8, 5)}, + {"shortdecimal_val:decimal(10, 5)", DECIMAL(10, 5)}, + {"shortdecimal_val:decimal(17, 5)", DECIMAL(17, 5)}}; + + for (const auto& pair : types) { + SCOPED_TESTVALUE_SET( + "facebook::velox::dwrf::ReaderBase::convertType", + std::function( + [&](TypePtr* type) { *type = pair.second; })); + testWithTypes( + pair.first, + [&]() { + makeIntDistribution( + "shortdecimal_val", + 10, // min + 100, // max + 22, // repeats + 19, // rareFrequency + -999, // rareMin + 30000, // rareMax + true); + }, + false, + {"shortdecimal_val"}, + 20); + } + options_.format = DwrfFormat::kDwrf; +} + +DEBUG_ONLY_TEST_F(E2EFilterTest, longDecimal) { + testutil::TestValue::enable(); + options_.format = DwrfFormat::kOrc; + const std::unordered_map types = { + {"longdecimal_val:decimal(30, 10)", DECIMAL(30, 10)}, + {"longdecimal_val:decimal(37, 15)", DECIMAL(37, 15)}}; + + SCOPED_TESTVALUE_SET( + "facebook::velox::dwrf::ProtoUtils::writeType", + std::function([&](bool* kindSet) { *kindSet = true; })); + for (const auto& pair : types) { + SCOPED_TESTVALUE_SET( + "facebook::velox::dwrf::ReaderBase::convertType", + std::function( + [&](TypePtr* type) { *type = pair.second; })); + testWithTypes( + pair.first, + [&]() { + makeIntDistribution( + "longdecimal_val", + 10, // min + 100, // max + 22, // repeats + 19, // rareFrequency + -999, // rareMin + 30000, // rareMax + true); + }, + false, + {"longdecimal_val"}, + 20); + } + options_.format = DwrfFormat::kDwrf; +} + TEST_F(E2EFilterTest, floatAndDouble) { testWithTypes( "float_val:float," diff --git a/velox/dwio/dwrf/utils/ProtoUtils.cpp b/velox/dwio/dwrf/utils/ProtoUtils.cpp index 405d2e79ddfb..e907c4126d3e 100644 --- a/velox/dwio/dwrf/utils/ProtoUtils.cpp +++ b/velox/dwio/dwrf/utils/ProtoUtils.cpp @@ -57,9 +57,22 @@ void ProtoUtils::writeType( if (parent) { parent->add_subtypes(footer.types_size() - 1); } - auto kind = - VELOX_STATIC_FIELD_DYNAMIC_DISPATCH(SchemaType, kind, type.kind()); - self->set_kind(kind); + bool kindSet = false; + if (type.kind() == TypeKind::HUGEINT) { + // Hugeint is not supported by DWRF, and this branch is only for ORC + // testing before the ORC footer write is implemented. + auto kind = SchemaType::kind; + self->set_kind(kind); + common::testutil::TestValue::adjust( + "facebook::velox::dwrf::ProtoUtils::writeType", &kindSet); + } else { + auto kind = + VELOX_STATIC_FIELD_DYNAMIC_DISPATCH(SchemaType, kind, type.kind()); + self->set_kind(kind); + kindSet = true; + } + VELOX_CHECK(kindSet, "Unknown type {}.", type.toString()); + switch (type.kind()) { case TypeKind::ROW: { auto& row = type.asRow(); diff --git a/velox/dwio/dwrf/writer/ColumnWriter.cpp b/velox/dwio/dwrf/writer/ColumnWriter.cpp index 2a4cf2077961..a1084f4d6104 100644 --- a/velox/dwio/dwrf/writer/ColumnWriter.cpp +++ b/velox/dwio/dwrf/writer/ColumnWriter.cpp @@ -2183,7 +2183,12 @@ std::unique_ptr BaseColumnWriter::create( context, type, sequence, onRecordPosition); ret->children_.reserve(type.size()); for (int32_t i = 0; i < type.size(); ++i) { - ret->children_.push_back(create(context, *type.childAt(i), sequence)); + ret->children_.push_back(create( + context, + *type.childAt(i), + sequence, + /*onRecordPosition=*/nullptr, + format)); } return ret; } @@ -2199,15 +2204,30 @@ std::unique_ptr BaseColumnWriter::create( } auto ret = std::make_unique( context, type, sequence, onRecordPosition); - ret->children_.push_back(create(context, *type.childAt(0), sequence)); - ret->children_.push_back(create(context, *type.childAt(1), sequence)); + ret->children_.push_back(create( + context, + *type.childAt(0), + sequence, + /*onRecordPosition=*/nullptr, + format)); + ret->children_.push_back(create( + context, + *type.childAt(1), + sequence, + /*onRecordPosition=*/nullptr, + format)); return ret; } case TypeKind::ARRAY: { VELOX_CHECK_EQ(type.size(), 1, "Array should have exactly one child"); auto ret = std::make_unique( context, type, sequence, onRecordPosition); - ret->children_.push_back(create(context, *type.childAt(0), sequence)); + ret->children_.push_back(create( + context, + *type.childAt(0), + sequence, + /*onRecordPosition=*/nullptr, + format)); return ret; } default: diff --git a/velox/dwio/dwrf/writer/Writer.cpp b/velox/dwio/dwrf/writer/Writer.cpp index d6011c38f8de..b5af93a2cc1b 100644 --- a/velox/dwio/dwrf/writer/Writer.cpp +++ b/velox/dwio/dwrf/writer/Writer.cpp @@ -200,7 +200,12 @@ Writer::Writer( } if (options.columnWriterFactory == nullptr) { - writer_ = BaseColumnWriter::create(writerBase_->getContext(), *schema_); + writer_ = BaseColumnWriter::create( + writerBase_->getContext(), + *schema_, + /*sequence=*/0, + /*onRecordPosition=*/nullptr, + options.format); } else { writer_ = options.columnWriterFactory(writerBase_->getContext(), *schema_); }