diff --git a/velox/dwio/dwrf/reader/ReaderBase.cpp b/velox/dwio/dwrf/reader/ReaderBase.cpp index 71dbc2e41510..16bd779b500a 100644 --- a/velox/dwio/dwrf/reader/ReaderBase.cpp +++ b/velox/dwio/dwrf/reader/ReaderBase.cpp @@ -343,13 +343,19 @@ std::shared_ptr ReaderBase::convertType( return SMALLINT(); case TypeKind::INTEGER: return INTEGER(); - case TypeKind::BIGINT: + case TypeKind::BIGINT: { + TypePtr converted; if (type.format() == DwrfFormat::kOrc && type.getOrcPtr()->kind() == proto::orc::Type_Kind_DECIMAL) { - return DECIMAL( - type.getOrcPtr()->precision(), type.getOrcPtr()->scale()); + converted = + DECIMAL(type.getOrcPtr()->precision(), type.getOrcPtr()->scale()); + } else { + converted = BIGINT(); + common::testutil::TestValue::adjust( + "facebook::velox::dwrf::ReaderBase::convertType", &converted); } - return BIGINT(); + return converted; + } case TypeKind::HUGEINT: if (type.format() == DwrfFormat::kOrc && type.getOrcPtr()->kind() == proto::orc::Type_Kind_DECIMAL) { diff --git a/velox/dwio/dwrf/test/E2EFilterTest.cpp b/velox/dwio/dwrf/test/E2EFilterTest.cpp index 43b67e91e550..c3790c8556fb 100644 --- a/velox/dwio/dwrf/test/E2EFilterTest.cpp +++ b/velox/dwio/dwrf/test/E2EFilterTest.cpp @@ -15,6 +15,7 @@ */ #include "velox/common/base/Portability.h" +#include "velox/common/base/tests/GTestUtils.h" #include "velox/common/testutil/TestValue.h" #include "velox/dwio/common/tests/utils/E2EFilterTestBase.h" #include "velox/dwio/dwrf/reader/DwrfReader.h" @@ -64,11 +65,11 @@ class E2EFilterTest : public E2EFilterTestBase { const TypePtr& type, const std::vector& batches, bool forRowGroupSkip = false) override { - auto options = createWriterOptions(type); + setWriterOptions(type); int32_t flushCounter = 0; // If we test row group skip, we have all the data in one stripe. For // scan, we start a stripe every 'flushEveryNBatches_' batches. - options.flushPolicyFactory = [&]() { + options_.flushPolicyFactory = [&]() { return std::make_unique([&]() { return forRowGroupSkip ? false : (++flushCounter % flushEveryNBatches_ == 0); @@ -80,8 +81,8 @@ class E2EFilterTest : public E2EFilterTestBase { dwio::common::FileSink::Options{.pool = leafPool_.get()}); ASSERT_TRUE(sink->isBuffered()); auto* sinkPtr = sink.get(); - options.memoryPool = rootPool_.get(); - writer_ = std::make_unique(std::move(sink), options); + options_.memoryPool = rootPool_.get(); + writer_ = std::make_unique(std::move(sink), options_); for (auto& batch : batches) { writer_->write(batch); } @@ -105,9 +106,10 @@ class E2EFilterTest : public E2EFilterTestBase { } std::unordered_set flatMapColumns_; + dwrf::WriterOptions options_; private: - dwrf::WriterOptions createWriterOptions(const TypePtr& type) { + void setWriterOptions(const TypePtr& type) { auto config = std::make_shared(); config->set(dwrf::Config::COMPRESSION, CompressionKind_NONE); config->set(dwrf::Config::USE_VINTS, useVInts_); @@ -148,10 +150,8 @@ class E2EFilterTest : public E2EFilterTestBase { config->set>>( dwrf::Config::MAP_FLAT_COLS_STRUCT_KEYS, mapFlatColsStructKeys); } - dwrf::WriterOptions options; - options.config = config; - options.schema = writerSchema; - return options; + options_.config = config; + options_.schema = writerSchema; } std::unique_ptr writer_; @@ -227,6 +227,74 @@ TEST_F(E2EFilterTest, byteRle) { 20); } +DEBUG_ONLY_TEST_F(E2EFilterTest, shortDecimal) { + testutil::TestValue::enable(); + options_.format = DwrfFormat::kOrc; + const std::unordered_map types = { + {"shortdecimal_val:decimal(8, 5)", DECIMAL(8, 5)}, + {"shortdecimal_val:decimal(10, 5)", DECIMAL(10, 5)}, + {"shortdecimal_val:decimal(17, 5)", DECIMAL(17, 5)}}; + + for (const auto& pair : types) { + SCOPED_TESTVALUE_SET( + "facebook::velox::dwrf::ReaderBase::convertType", + std::function( + [&](TypePtr* type) { *type = pair.second; })); + testWithTypes( + pair.first, + [&]() { + makeIntDistribution( + "shortdecimal_val", + 10, // min + 100, // max + 22, // repeats + 19, // rareFrequency + -999, // rareMin + 30000, // rareMax + true); + }, + false, + {"shortdecimal_val"}, + 20); + } + options_.format = DwrfFormat::kDwrf; +} + +DEBUG_ONLY_TEST_F(E2EFilterTest, longDecimal) { + testutil::TestValue::enable(); + options_.format = DwrfFormat::kOrc; + const std::unordered_map types = { + {"longdecimal_val:decimal(30, 10)", DECIMAL(30, 10)}, + {"longdecimal_val:decimal(37, 15)", DECIMAL(37, 15)}}; + + SCOPED_TESTVALUE_SET( + "facebook::velox::dwrf::ProtoUtils::writeType", + std::function([&](bool* kindSet) { *kindSet = true; })); + for (const auto& pair : types) { + SCOPED_TESTVALUE_SET( + "facebook::velox::dwrf::ReaderBase::convertType", + std::function( + [&](TypePtr* type) { *type = pair.second; })); + testWithTypes( + pair.first, + [&]() { + makeIntDistribution( + "longdecimal_val", + 10, // min + 100, // max + 22, // repeats + 19, // rareFrequency + -999, // rareMin + 30000, // rareMax + true); + }, + true, + {"longdecimal_val"}, + 20); + } + options_.format = DwrfFormat::kDwrf; +} + TEST_F(E2EFilterTest, floatAndDouble) { testWithTypes( "float_val:float," diff --git a/velox/dwio/dwrf/utils/ProtoUtils.cpp b/velox/dwio/dwrf/utils/ProtoUtils.cpp index 405d2e79ddfb..e907c4126d3e 100644 --- a/velox/dwio/dwrf/utils/ProtoUtils.cpp +++ b/velox/dwio/dwrf/utils/ProtoUtils.cpp @@ -57,9 +57,22 @@ void ProtoUtils::writeType( if (parent) { parent->add_subtypes(footer.types_size() - 1); } - auto kind = - VELOX_STATIC_FIELD_DYNAMIC_DISPATCH(SchemaType, kind, type.kind()); - self->set_kind(kind); + bool kindSet = false; + if (type.kind() == TypeKind::HUGEINT) { + // Hugeint is not supported by DWRF, and this branch is only for ORC + // testing before the ORC footer write is implemented. + auto kind = SchemaType::kind; + self->set_kind(kind); + common::testutil::TestValue::adjust( + "facebook::velox::dwrf::ProtoUtils::writeType", &kindSet); + } else { + auto kind = + VELOX_STATIC_FIELD_DYNAMIC_DISPATCH(SchemaType, kind, type.kind()); + self->set_kind(kind); + kindSet = true; + } + VELOX_CHECK(kindSet, "Unknown type {}.", type.toString()); + switch (type.kind()) { case TypeKind::ROW: { auto& row = type.asRow(); diff --git a/velox/dwio/dwrf/writer/ColumnWriter.cpp b/velox/dwio/dwrf/writer/ColumnWriter.cpp index 2a4cf2077961..a1084f4d6104 100644 --- a/velox/dwio/dwrf/writer/ColumnWriter.cpp +++ b/velox/dwio/dwrf/writer/ColumnWriter.cpp @@ -2183,7 +2183,12 @@ std::unique_ptr BaseColumnWriter::create( context, type, sequence, onRecordPosition); ret->children_.reserve(type.size()); for (int32_t i = 0; i < type.size(); ++i) { - ret->children_.push_back(create(context, *type.childAt(i), sequence)); + ret->children_.push_back(create( + context, + *type.childAt(i), + sequence, + /*onRecordPosition=*/nullptr, + format)); } return ret; } @@ -2199,15 +2204,30 @@ std::unique_ptr BaseColumnWriter::create( } auto ret = std::make_unique( context, type, sequence, onRecordPosition); - ret->children_.push_back(create(context, *type.childAt(0), sequence)); - ret->children_.push_back(create(context, *type.childAt(1), sequence)); + ret->children_.push_back(create( + context, + *type.childAt(0), + sequence, + /*onRecordPosition=*/nullptr, + format)); + ret->children_.push_back(create( + context, + *type.childAt(1), + sequence, + /*onRecordPosition=*/nullptr, + format)); return ret; } case TypeKind::ARRAY: { VELOX_CHECK_EQ(type.size(), 1, "Array should have exactly one child"); auto ret = std::make_unique( context, type, sequence, onRecordPosition); - ret->children_.push_back(create(context, *type.childAt(0), sequence)); + ret->children_.push_back(create( + context, + *type.childAt(0), + sequence, + /*onRecordPosition=*/nullptr, + format)); return ret; } default: diff --git a/velox/dwio/dwrf/writer/Writer.cpp b/velox/dwio/dwrf/writer/Writer.cpp index d6011c38f8de..b5af93a2cc1b 100644 --- a/velox/dwio/dwrf/writer/Writer.cpp +++ b/velox/dwio/dwrf/writer/Writer.cpp @@ -200,7 +200,12 @@ Writer::Writer( } if (options.columnWriterFactory == nullptr) { - writer_ = BaseColumnWriter::create(writerBase_->getContext(), *schema_); + writer_ = BaseColumnWriter::create( + writerBase_->getContext(), + *schema_, + /*sequence=*/0, + /*onRecordPosition=*/nullptr, + options.format); } else { writer_ = options.columnWriterFactory(writerBase_->getContext(), *schema_); }