Skip to content

Commit 26cde39

Browse files
authored
[GLUTEN-8921][GLUTEN-8922][CH] Fix checkDecimalOverflowSparkOrNull and lead function (#8929)
1 parent 0257def commit 26cde39

File tree

3 files changed

+53
-10
lines changed

3 files changed

+53
-10
lines changed

backends-clickhouse/src/test/scala/org/apache/gluten/execution/compatibility/GlutenClickhouseFunctionSuite.scala

+34
Original file line numberDiff line numberDiff line change
@@ -442,4 +442,38 @@ class GlutenClickhouseFunctionSuite extends GlutenClickHouseTPCHAbstractSuite {
442442
}
443443
}
444444

445+
test("GLUTEN-8921: Type mismatch at checkDecimalOverflowSparkOrNull") {
446+
compareResultsAgainstVanillaSpark(
447+
"""
448+
|select l_shipdate, avg(l_quantity), count(0) over() COU,
449+
|SUM(-1.1) over() SU, AVG(-2) over() AV,
450+
|max(-1.1) over() MA, min(-3) over() MI
451+
|from lineitem
452+
|where l_shipdate <= date'1998-09-02'
453+
|group by l_shipdate
454+
|order by l_shipdate
455+
""".stripMargin,
456+
true,
457+
{ _ => }
458+
)
459+
}
460+
461+
test("GLUTEN-8922: Incorrect result in lead function with constant col") {
462+
compareResultsAgainstVanillaSpark(
463+
"""
464+
|select l_shipdate,
465+
|FIRST_VALUE(-2) over() FI,
466+
|LAST_VALUE(-2) over() LA,
467+
|lag(-2) over(order by l_shipdate) lag0,
468+
|lead(-2) over(order by l_shipdate) lead0
469+
|from lineitem
470+
|where l_shipdate <= date'1998-09-02'
471+
|group by l_shipdate
472+
|order by l_shipdate
473+
""".stripMargin,
474+
true,
475+
{ _ => }
476+
)
477+
}
478+
445479
}

cpp-ch/local-engine/Functions/SparkFunctionCheckDecimalOverflow.cpp

+15-6
Original file line numberDiff line numberDiff line change
@@ -100,14 +100,20 @@ class FunctionCheckDecimalOverflow : public IFunction
100100
UInt32 precision = extractArgument(arguments[1]);
101101
UInt32 scale = extractArgument(arguments[2]);
102102
auto return_type = createDecimal<DataTypeDecimal>(precision, scale);
103-
if constexpr (exception_mode == CheckExceptionMode::Null)
104-
{
105-
if (!arguments[0].type->isNullable())
106-
return std::make_shared<DataTypeNullable>(return_type);
107-
}
103+
if (isReturnTypeNullable(arguments[0]))
104+
return std::make_shared<DataTypeNullable>(return_type);
108105
return return_type;
109106
}
110107

108+
bool isReturnTypeNullable(const ColumnWithTypeAndName & arg) const
109+
{
110+
if constexpr (exception_mode == CheckExceptionMode::Null)
111+
return true;
112+
if (arg.type->isNullable())
113+
return true;
114+
return false;
115+
}
116+
111117
ColumnPtr executeImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr &, size_t input_rows_count) const override
112118
{
113119
UInt32 to_precision = extractArgument(arguments[1]);
@@ -133,7 +139,10 @@ class FunctionCheckDecimalOverflow : public IFunction
133139
auto from_scale = getDecimalScale(*src_col.type);
134140
if (from_precision == to_precision && from_scale == to_scale)
135141
{
136-
dst_col = src_col.column;
142+
if (isReturnTypeNullable(arguments[0]))
143+
dst_col = makeNullable(src_col.column);
144+
else
145+
dst_col = src_col.column;
137146
return true;
138147
}
139148
}

cpp-ch/local-engine/Parser/aggregate_function_parser/LeadLagParser.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -33,15 +33,15 @@ LeadParser::parseFunctionArguments(const CommonFunctionInfo & func_info, DB::Act
3333
/// The 3rd arg is default value
3434
/// when it is set to null, the 1st arg must be nullable
3535
const auto & arg2 = func_info.arguments[2].value();
36-
const auto * arg0_col = actions_dag.getInputs()[arg0.selection().direct_reference().struct_field().field()];
36+
const auto * arg0_col = parseExpression(actions_dag, arg0);
3737
auto arg0_col_name = arg0_col->result_name;
3838
auto arg0_col_type= arg0_col->result_type;
3939
const DB::ActionsDAG::Node * node = nullptr;
4040
if (arg2.has_literal() && arg2.literal().has_null() && !arg0_col_type->isNullable())
4141
{
4242
node = ActionsDAGUtil::convertNodeType(
4343
actions_dag,
44-
&actions_dag.findInOutputs(arg0_col_name),
44+
arg0_col,
4545
DB::makeNullable(arg0_col_type),
4646
arg0_col_name);
4747
actions_dag.addOrReplaceInOutputs(*node);
@@ -76,15 +76,15 @@ LagParser::parseFunctionArguments(const CommonFunctionInfo & func_info, DB::Acti
7676
/// The 3rd arg is default value
7777
/// when it is set to null, the 1st arg must be nullable
7878
const auto & arg2 = func_info.arguments[2].value();
79-
const auto * arg0_col = actions_dag.getInputs()[arg0.selection().direct_reference().struct_field().field()];
79+
const auto * arg0_col = parseExpression(actions_dag, arg0);
8080
auto arg0_col_name = arg0_col->result_name;
8181
auto arg0_col_type = arg0_col->result_type;
8282
const DB::ActionsDAG::Node * node = nullptr;
8383
if (arg2.has_literal() && arg2.literal().has_null() && !arg0_col->result_type->isNullable())
8484
{
8585
node = ActionsDAGUtil::convertNodeType(
8686
actions_dag,
87-
&actions_dag.findInOutputs(arg0_col_name),
87+
arg0_col,
8888
makeNullable(arg0_col_type),
8989
arg0_col_name);
9090
actions_dag.addOrReplaceInOutputs(*node);

0 commit comments

Comments
 (0)