|
31 | 31 | #include <Parser/AdvancedParametersParseUtil.h>
|
32 | 32 | #include <Parser/ExpressionParser.h>
|
33 | 33 | #include <Parsers/ASTIdentifier.h>
|
| 34 | +#include <Parser/SubstraitParserUtils.h> |
34 | 35 | #include <Processors/QueryPlan/ExpressionStep.h>
|
35 | 36 | #include <Processors/QueryPlan/FilterStep.h>
|
36 | 37 | #include <Processors/QueryPlan/JoinStep.h>
|
@@ -114,10 +115,9 @@ std::unordered_set<DB::JoinTableSide> JoinRelParser::extractTableSidesFromExpres
|
114 | 115 | table_sides.insert(table_sides_from_arg.begin(), table_sides_from_arg.end());
|
115 | 116 | }
|
116 | 117 | }
|
117 |
| - else if (expr.has_selection() && expr.selection().has_direct_reference() && expr.selection().direct_reference().has_struct_field()) |
| 118 | + else if (auto field = SubstraitParserUtils::getStructFieldIndex(expr)) |
118 | 119 | {
|
119 |
| - auto pos = expr.selection().direct_reference().struct_field().field(); |
120 |
| - if (pos < left_header.columns()) |
| 120 | + if (*field < left_header.columns()) |
121 | 121 | table_sides.insert(DB::JoinTableSide::Left);
|
122 | 122 | else
|
123 | 123 | table_sides.insert(DB::JoinTableSide::Right);
|
@@ -272,15 +272,10 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const substrait::JoinRel & join, DB::Q
|
272 | 272 | auto input_header = left->getCurrentHeader();
|
273 | 273 | DB::ActionsDAG filter_is_not_null_dag{input_header.getColumnsWithTypeAndName()};
|
274 | 274 | // when is_null_aware_anti_join is true, there is only one join key
|
275 |
| - const auto * key_field = filter_is_not_null_dag.getInputs()[join.expression() |
276 |
| - .scalar_function() |
277 |
| - .arguments() |
278 |
| - .at(0) |
279 |
| - .value() |
280 |
| - .selection() |
281 |
| - .direct_reference() |
282 |
| - .struct_field() |
283 |
| - .field()]; |
| 275 | + auto field_index = SubstraitParserUtils::getStructFieldIndex(join.expression().scalar_function().arguments(0).value()); |
| 276 | + if (!field_index) |
| 277 | + throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "The join key is not found in the expression."); |
| 278 | + const auto * key_field = filter_is_not_null_dag.getInputs()[*field_index]; |
284 | 279 |
|
285 | 280 | auto result_node = filter_is_not_null_dag.tryFindInOutputs(key_field->result_name);
|
286 | 281 | // add a function isNotNull to filter the null key on the left side
|
@@ -480,12 +475,12 @@ void JoinRelParser::collectJoinKeys(
|
480 | 475 | size_t left_pos = 0, right_pos = 0;
|
481 | 476 | for (const auto & arg : current_expr->scalar_function().arguments())
|
482 | 477 | {
|
483 |
| - if (!arg.value().has_selection() || !arg.value().selection().has_direct_reference() |
484 |
| - || !arg.value().selection().direct_reference().has_struct_field()) |
| 478 | + auto field_index = SubstraitParserUtils::getStructFieldIndex(arg.value()); |
| 479 | + if (!field_index) |
485 | 480 | {
|
486 | 481 | throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "A column reference is expected");
|
487 | 482 | }
|
488 |
| - auto col_pos_ref = arg.value().selection().direct_reference().struct_field().field(); |
| 483 | + auto col_pos_ref = *field_index; |
489 | 484 | if (col_pos_ref < left_header.columns())
|
490 | 485 | {
|
491 | 486 | left_pos = col_pos_ref;
|
@@ -550,8 +545,7 @@ bool JoinRelParser::applyJoinFilter(
|
550 | 545 | std::vector<substrait::Expression> exprs;
|
551 | 546 | for (size_t i = 0; i < header.columns(); ++i)
|
552 | 547 | {
|
553 |
| - substrait::Expression expr; |
554 |
| - expr.mutable_selection()->mutable_direct_reference()->mutable_struct_field()->set_field(i); |
| 548 | + substrait::Expression expr = SubstraitParserUtils::buildStructFieldExpression(i); |
555 | 549 | exprs.emplace_back(expr);
|
556 | 550 | }
|
557 | 551 | return exprs;
|
@@ -680,23 +674,17 @@ bool JoinRelParser::couldRewriteToMultiJoinOnClauses(
|
680 | 674 | dfs_visit_and_expr(and_exprs, args[1].value());
|
681 | 675 | };
|
682 | 676 |
|
683 |
| - auto get_field_ref = [](const substrait::Expression & e) -> std::optional<Int32> |
684 |
| - { |
685 |
| - if (e.has_selection() && e.selection().has_direct_reference() && e.selection().direct_reference().has_struct_field()) |
686 |
| - return std::optional<Int32>(e.selection().direct_reference().struct_field().field()); |
687 |
| - return {}; |
688 |
| - }; |
689 | 677 | auto visit_equal_expr = [&](const substrait::Expression & e) -> std::optional<std::pair<String, String>>
|
690 | 678 | {
|
691 | 679 | if (!check_function("equals", e))
|
692 | 680 | return {};
|
693 | 681 | const auto & args = e.scalar_function().arguments();
|
694 |
| - auto l_field_ref = get_field_ref(args[0].value()); |
695 |
| - auto r_field_ref = get_field_ref(args[1].value()); |
| 682 | + auto l_field_ref = SubstraitParserUtils::getStructFieldIndex(args[0].value()); |
| 683 | + auto r_field_ref = SubstraitParserUtils::getStructFieldIndex(args[1].value()); |
696 | 684 | if (!l_field_ref.has_value() || !r_field_ref.has_value())
|
697 | 685 | return {};
|
698 |
| - size_t l_pos = static_cast<size_t>(*l_field_ref); |
699 |
| - size_t r_pos = static_cast<size_t>(*r_field_ref); |
| 686 | + size_t l_pos = *l_field_ref; |
| 687 | + size_t r_pos = *r_field_ref; |
700 | 688 | size_t l_cols = left_header.columns();
|
701 | 689 | size_t total_cols = l_cols + right_header.columns();
|
702 | 690 |
|
|
0 commit comments