Skip to content

Commit 3524973

Browse files
committed
simplify subtrait parsing
1 parent 6ece721 commit 3524973

12 files changed

+110
-73
lines changed

cpp-ch/local-engine/Builder/SerializedPlanBuilder.cpp

-8
Original file line numberDiff line numberDiff line change
@@ -312,14 +312,6 @@ void SerializedPlanBuilder::buildType(const DB::DataTypePtr & ch_type, String &
312312
substrait_type = pb->SerializeAsString();
313313
}
314314

315-
316-
substrait::Expression * selection(int32_t field_id)
317-
{
318-
substrait::Expression * rel = new substrait::Expression();
319-
auto * selection = rel->mutable_selection();
320-
selection->mutable_direct_reference()->mutable_struct_field()->set_field(field_id);
321-
return rel;
322-
}
323315
substrait::Expression * scalarFunction(int32_t id, ExpressionList args)
324316
{
325317
substrait::Expression * rel = new substrait::Expression();

cpp-ch/local-engine/Parser/ExpressionParser.cpp

+6-5
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include <Parser/FunctionParser.h>
3636
#include <Parser/ParserContext.h>
3737
#include <Parser/SerializedPlanParser.h>
38+
#include <Parser/SubstraitParserUtils.h>
3839
#include <Parser/TypeParser.h>
3940
#include <Poco/Logger.h>
4041
#include <Common/BlockTypeUtils.h>
@@ -293,10 +294,11 @@ ExpressionParser::NodeRawConstPtr ExpressionParser::parseExpression(ActionsDAG &
293294
}
294295

295296
case substrait::Expression::RexTypeCase::kSelection: {
296-
if (!rel.selection().has_direct_reference() || !rel.selection().direct_reference().has_struct_field())
297+
auto field_index = SubstraitParserUtils::getStructFieldIndex(rel);
298+
if (!field_index)
297299
throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Can only have direct struct references in selections");
298300

299-
const auto * field = actions_dag.getInputs()[rel.selection().direct_reference().struct_field().field()];
301+
const auto * field = actions_dag.getInputs()[*field_index];
300302
return field;
301303
}
302304

@@ -521,10 +523,9 @@ ExpressionParser::expressionsToActionsDAG(const std::vector<substrait::Expressio
521523

522524
for (const auto & expr : expressions)
523525
{
524-
if (expr.has_selection())
526+
if (auto field_index = SubstraitParserUtils::getStructFieldIndex(expr))
525527
{
526-
auto position = expr.selection().direct_reference().struct_field().field();
527-
auto col_name = header.getByPosition(position).name;
528+
auto col_name = header.getByPosition(*field_index).name;
528529
const DB::ActionsDAG::Node * field = actions_dag.tryFindInOutputs(col_name);
529530
if (!field)
530531
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Not found {} in actions dag's output", col_name);

cpp-ch/local-engine/Parser/FunctionExecutor.cpp

+2-4
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include <Builder/SerializedPlanBuilder.h>
2020
#include <Core/ColumnWithTypeAndName.h>
2121
#include <Common/BlockTypeUtils.h>
22+
#include <Parser/SubstraitParserUtils.h>
2223

2324
namespace DB
2425
{
@@ -57,10 +58,7 @@ void FunctionExecutor::buildExpression()
5758
[&](const auto & ) {
5859
substrait::FunctionArgument argument;
5960
auto * value = argument.mutable_value();
60-
auto * selection = value->mutable_selection();
61-
auto * direct_reference = selection->mutable_direct_reference();
62-
auto * struct_field = direct_reference->mutable_struct_field();
63-
struct_field->set_field(field++);
61+
value->CopyFrom(SubstraitParserUtils::buildStructFieldExpression(field++));
6462

6563
arguments->Add(std::move(argument));
6664
});

cpp-ch/local-engine/Parser/RelParsers/AggregateRelParser.cpp

+6-2
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include <Operator/StreamingAggregatingStep.h>
2929
#include <Parser/AdvancedParametersParseUtil.h>
3030
#include <Parser/AggregateFunctionParser.h>
31+
#include <Parser/SubstraitParserUtils.h>
3132
#include <Processors/QueryPlan/AggregatingStep.h>
3233
#include <Processors/QueryPlan/ExpressionStep.h>
3334
#include <Processors/QueryPlan/MergingAggregatedStep.h>
@@ -186,10 +187,13 @@ void AggregateRelParser::setup(DB::QueryPlanPtr query_plan, const substrait::Rel
186187
if (aggregate_rel->groupings_size() == 1)
187188
{
188189
for (const auto & expr : aggregate_rel->groupings(0).grouping_expressions())
189-
if (expr.has_selection() && expr.selection().has_direct_reference())
190-
grouping_keys.push_back(input_header.getByPosition(expr.selection().direct_reference().struct_field().field()).name);
190+
{
191+
auto field_index = SubstraitParserUtils::getStructFieldIndex(expr);
192+
if (!field_index)
193+
grouping_keys.push_back(input_header.getByPosition(*field_index).name);
191194
else
192195
throw Exception(ErrorCodes::BAD_ARGUMENTS, "unsupported group expression: {}", expr.DebugString());
196+
}
193197
}
194198
else if (aggregate_rel->groupings_size() != 0)
195199
{

cpp-ch/local-engine/Parser/RelParsers/ExpandRelParser.cpp

+7-7
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include <Operator/AdvancedExpandStep.h>
2424
#include <Operator/ExpandStep.h>
2525
#include <Parser/RelParsers/RelParser.h>
26+
#include <Parser/SubstraitParserUtils.h>
2627
#include <Processors/QueryPlan/QueryPlan.h>
2728
#include <Common/logger_useful.h>
2829

@@ -83,18 +84,17 @@ ExpandField ExpandRelParser::buildExpandField(const DB::Block & header, const su
8384
for (int i = 0; i < expand_col_size; ++i)
8485
{
8586
const auto & project_expr = projections.switching_field().duplicates(i);
86-
if (project_expr.has_selection())
87+
if (auto field_index = SubstraitParserUtils::getStructFieldIndex(project_expr))
8788
{
88-
auto field = project_expr.selection().direct_reference().struct_field().field();
8989
kinds.push_back(ExpandFieldKind::EXPAND_FIELD_KIND_SELECTION);
90-
fields.push_back(field);
91-
if (field >= header.columns())
90+
fields.push_back(*field_index);
91+
if (*field_index >= header.columns())
9292
{
9393
throw DB::Exception(
94-
DB::ErrorCodes::LOGICAL_ERROR, "Field index out of range: {}, header: {}", field, header.dumpStructure());
94+
DB::ErrorCodes::LOGICAL_ERROR, "Field index out of range: {}, header: {}", *field_index, header.dumpStructure());
9595
}
96-
updateType(types[i], header.getByPosition(field).type);
97-
const auto & name = header.getByPosition(field).name;
96+
updateType(types[i], header.getByPosition(*field_index).type);
97+
const auto & name = header.getByPosition(*field_index).name;
9898
if (names[i].empty())
9999
{
100100
if (distinct_names.contains(name))

cpp-ch/local-engine/Parser/RelParsers/GroupLimitRelParser.cpp

+5-4
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
#include <Parser/AdvancedParametersParseUtil.h>
3737
#include <Parser/RelParsers/SortParsingUtils.h>
3838
#include <Parser/RelParsers/SortRelParser.h>
39+
#include <Parser/SubstraitParserUtils.h>
3940
#include <Processors/QueryPlan/ExpressionStep.h>
4041
#include <Processors/QueryPlan/FilterStep.h>
4142
#include <Processors/QueryPlan/QueryPlan.h>
@@ -100,8 +101,8 @@ static std::vector<size_t> parsePartitionFields(const google::protobuf::Repeated
100101
{
101102
std::vector<size_t> fields;
102103
for (const auto & expr : expressions)
103-
if (expr.has_selection())
104-
fields.push_back(static_cast<size_t>(expr.selection().direct_reference().struct_field().field()));
104+
if (auto field_index = SubstraitParserUtils::getStructFieldIndex(expr))
105+
fields.push_back(*field_index);
105106
else if (expr.has_literal())
106107
continue;
107108
else
@@ -115,8 +116,8 @@ std::vector<size_t> parseSortFields(const google::protobuf::RepeatedPtrField<sub
115116
for (const auto sort_field : sort_fields)
116117
if (sort_field.expr().has_literal())
117118
continue;
118-
else if (sort_field.expr().has_selection())
119-
fields.push_back(static_cast<size_t>(sort_field.expr().selection().direct_reference().struct_field().field()));
119+
else if (auto field_index = SubstraitParserUtils::getStructFieldIndex(sort_field.expr()))
120+
fields.push_back(*field_index);
120121
else
121122
throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unknown expression: {}", sort_field.expr().DebugString());
122123
return fields;

cpp-ch/local-engine/Parser/RelParsers/JoinRelParser.cpp

+15-27
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include <Parser/AdvancedParametersParseUtil.h>
3232
#include <Parser/ExpressionParser.h>
3333
#include <Parsers/ASTIdentifier.h>
34+
#include <Parser/SubstraitParserUtils.h>
3435
#include <Processors/QueryPlan/ExpressionStep.h>
3536
#include <Processors/QueryPlan/FilterStep.h>
3637
#include <Processors/QueryPlan/JoinStep.h>
@@ -114,10 +115,9 @@ std::unordered_set<DB::JoinTableSide> JoinRelParser::extractTableSidesFromExpres
114115
table_sides.insert(table_sides_from_arg.begin(), table_sides_from_arg.end());
115116
}
116117
}
117-
else if (expr.has_selection() && expr.selection().has_direct_reference() && expr.selection().direct_reference().has_struct_field())
118+
else if (auto field = SubstraitParserUtils::getStructFieldIndex(expr))
118119
{
119-
auto pos = expr.selection().direct_reference().struct_field().field();
120-
if (pos < left_header.columns())
120+
if (*field < left_header.columns())
121121
table_sides.insert(DB::JoinTableSide::Left);
122122
else
123123
table_sides.insert(DB::JoinTableSide::Right);
@@ -272,15 +272,10 @@ DB::QueryPlanPtr JoinRelParser::parseJoin(const substrait::JoinRel & join, DB::Q
272272
auto input_header = left->getCurrentHeader();
273273
DB::ActionsDAG filter_is_not_null_dag{input_header.getColumnsWithTypeAndName()};
274274
// when is_null_aware_anti_join is true, there is only one join key
275-
const auto * key_field = filter_is_not_null_dag.getInputs()[join.expression()
276-
.scalar_function()
277-
.arguments()
278-
.at(0)
279-
.value()
280-
.selection()
281-
.direct_reference()
282-
.struct_field()
283-
.field()];
275+
auto field_index = SubstraitParserUtils::getStructFieldIndex(join.expression().scalar_function().arguments(0).value());
276+
if (!field_index)
277+
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "The join key is not found in the expression.");
278+
const auto * key_field = filter_is_not_null_dag.getInputs()[*field_index];
284279

285280
auto result_node = filter_is_not_null_dag.tryFindInOutputs(key_field->result_name);
286281
// add a function isNotNull to filter the null key on the left side
@@ -480,12 +475,12 @@ void JoinRelParser::collectJoinKeys(
480475
size_t left_pos = 0, right_pos = 0;
481476
for (const auto & arg : current_expr->scalar_function().arguments())
482477
{
483-
if (!arg.value().has_selection() || !arg.value().selection().has_direct_reference()
484-
|| !arg.value().selection().direct_reference().has_struct_field())
478+
auto field_index = SubstraitParserUtils::getStructFieldIndex(arg.value());
479+
if (!field_index)
485480
{
486481
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "A column reference is expected");
487482
}
488-
auto col_pos_ref = arg.value().selection().direct_reference().struct_field().field();
483+
auto col_pos_ref = *field_index;
489484
if (col_pos_ref < left_header.columns())
490485
{
491486
left_pos = col_pos_ref;
@@ -550,8 +545,7 @@ bool JoinRelParser::applyJoinFilter(
550545
std::vector<substrait::Expression> exprs;
551546
for (size_t i = 0; i < header.columns(); ++i)
552547
{
553-
substrait::Expression expr;
554-
expr.mutable_selection()->mutable_direct_reference()->mutable_struct_field()->set_field(i);
548+
substrait::Expression expr = SubstraitParserUtils::buildStructFieldExpression(i);
555549
exprs.emplace_back(expr);
556550
}
557551
return exprs;
@@ -680,23 +674,17 @@ bool JoinRelParser::couldRewriteToMultiJoinOnClauses(
680674
dfs_visit_and_expr(and_exprs, args[1].value());
681675
};
682676

683-
auto get_field_ref = [](const substrait::Expression & e) -> std::optional<Int32>
684-
{
685-
if (e.has_selection() && e.selection().has_direct_reference() && e.selection().direct_reference().has_struct_field())
686-
return std::optional<Int32>(e.selection().direct_reference().struct_field().field());
687-
return {};
688-
};
689677
auto visit_equal_expr = [&](const substrait::Expression & e) -> std::optional<std::pair<String, String>>
690678
{
691679
if (!check_function("equals", e))
692680
return {};
693681
const auto & args = e.scalar_function().arguments();
694-
auto l_field_ref = get_field_ref(args[0].value());
695-
auto r_field_ref = get_field_ref(args[1].value());
682+
auto l_field_ref = SubstraitParserUtils::getStructFieldIndex(args[0].value());
683+
auto r_field_ref = SubstraitParserUtils::getStructFieldIndex(args[1].value());
696684
if (!l_field_ref.has_value() || !r_field_ref.has_value())
697685
return {};
698-
size_t l_pos = static_cast<size_t>(*l_field_ref);
699-
size_t r_pos = static_cast<size_t>(*r_field_ref);
686+
size_t l_pos = *l_field_ref;
687+
size_t r_pos = *r_field_ref;
700688
size_t l_cols = left_header.columns();
701689
size_t total_cols = l_cols + right_header.columns();
702690

cpp-ch/local-engine/Parser/RelParsers/MergeTreeRelParser.cpp

+5-3
Original file line numberDiff line numberDiff line change
@@ -317,9 +317,11 @@ void MergeTreeRelParser::collectColumns(const substrait::Expression & rel, NameS
317317
}
318318

319319
case substrait::Expression::RexTypeCase::kSelection: {
320-
const size_t idx = rel.selection().direct_reference().struct_field().field();
321-
if (const Names names = block.getNames(); names.size() > idx)
322-
columns.insert(names[idx]);
320+
auto idx = SubstraitParserUtils::getStructFieldIndex(rel);
321+
if (!idx)
322+
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Selection node must have direct reference.");
323+
if (const Names names = block.getNames(); names.size() > *idx)
324+
columns.insert(names[*idx]);
323325

324326
return;
325327
}

cpp-ch/local-engine/Parser/RelParsers/SortParsingUtils.cpp

+12-10
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include <Poco/Logger.h>
2121
#include <Common/Exception.h>
2222
#include <Common/logger_useful.h>
23+
#include <Parser/SubstraitParserUtils.h>
2324

2425
namespace DB::ErrorCodes
2526
{
@@ -33,16 +34,18 @@ DB::SortDescription parseSortFields(const DB::Block & header, const google::prot
3334
{
3435
DB::SortDescription description;
3536
for (const auto & expr : expressions)
36-
if (expr.has_selection())
37+
{
38+
auto field_index = SubstraitParserUtils::getStructFieldIndex(expr);
39+
if (field_index)
3740
{
38-
auto pos = expr.selection().direct_reference().struct_field().field();
39-
const auto & col_name = header.getByPosition(pos).name;
41+
const auto & col_name = header.getByPosition(*field_index).name;
4042
description.push_back(DB::SortColumnDescription(col_name, 1, -1));
4143
}
4244
else if (expr.has_literal())
4345
continue;
4446
else
4547
throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unknow expression as sort field: {}", expr.DebugString());
48+
}
4649
return description;
4750
}
4851

@@ -58,17 +61,16 @@ DB::SortDescription parseSortFields(const DB::Block & header, const google::prot
5861
if (sort_field.expr().has_literal())
5962
continue;
6063

61-
if (!sort_field.expr().has_selection() || !sort_field.expr().selection().has_direct_reference()
62-
|| !sort_field.expr().selection().direct_reference().has_struct_field())
64+
auto field_index = SubstraitParserUtils::getStructFieldIndex(sort_field.expr());
65+
if(!field_index)
6366
{
6467
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Unsupport sort field");
6568
}
66-
auto field_pos = sort_field.expr().selection().direct_reference().struct_field().field();
6769

6870
auto direction_iter = direction_map.find(sort_field.direction());
6971
if (direction_iter == direction_map.end())
7072
throw DB::Exception(DB::ErrorCodes::LOGICAL_ERROR, "Unsuppor sort direction: {}", sort_field.direction());
71-
const auto & col_name = header.getByPosition(field_pos).name;
73+
const auto & col_name = header.getByPosition(*field_index).name;
7274
sort_descr.emplace_back(col_name, direction_iter->second.first, direction_iter->second.second);
7375
}
7476
return sort_descr;
@@ -86,13 +88,13 @@ buildSQLLikeSortDescription(const DB::Block & header, const google::protobuf::Re
8688
auto it = order_directions.find(sort_field.direction());
8789
if (it == order_directions.end())
8890
throw DB::Exception(DB::ErrorCodes::BAD_ARGUMENTS, "Unknow sort direction: {}", sort_field.direction());
89-
if (!sort_field.expr().has_selection())
91+
auto field_index = SubstraitParserUtils::getStructFieldIndex(sort_field.expr());
92+
if (!field_index)
9093
{
9194
throw DB::Exception(
9295
DB::ErrorCodes::BAD_ARGUMENTS, "Sort field must be a column reference. but got {}", sort_field.DebugString());
9396
}
94-
auto ref = sort_field.expr().selection().direct_reference().struct_field().field();
95-
const auto & col_name = header.getByPosition(ref).name;
97+
const auto & col_name = header.getByPosition(*field_index).name;
9698
if (n)
9799
ostr << String(",");
98100
// the col_name may contain '#' which can may ch fail to parse.

cpp-ch/local-engine/Parser/SerializedPlanParser.cpp

+2-3
Original file line numberDiff line numberDiff line change
@@ -416,10 +416,9 @@ void NonNullableColumnsResolver::visitNonNullable(const substrait::Expression &
416416
visitNonNullable(scalar_function.arguments(1).value());
417417
}
418418
}
419-
else if (expr.has_selection())
419+
else if (auto field_index = SubstraitParserUtils::getStructFieldIndex(expr))
420420
{
421-
const auto & selection = expr.selection();
422-
auto column_pos = selection.direct_reference().struct_field().field();
421+
const auto & column_pos = *field_index;
423422
auto column_name = header.getByPosition(column_pos).name;
424423
collected_columns.insert(column_name);
425424
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
#include <Parser/SubstraitParserUtils.h>
18+
#include "substrait/algebra.pb.h"
19+
namespace local_engine::SubstraitParserUtils
20+
{
21+
std::optional<size_t> getStructFieldIndex(const substrait::Expression & e)
22+
{
23+
if (!e.has_selection())
24+
return {};
25+
const auto & select = e.selection();
26+
if (!select.has_direct_reference())
27+
return {};
28+
const auto & ref = select.direct_reference();
29+
if (!ref.has_struct_field())
30+
return {};
31+
return ref.struct_field().field();
32+
}
33+
34+
substrait::Expression buildStructFieldExpression(size_t index)
35+
{
36+
substrait::Expression e;
37+
e.mutable_selection()->mutable_direct_reference()->mutable_struct_field()->set_field(index);
38+
return e;
39+
}
40+
41+
}

0 commit comments

Comments
 (0)