|
13 | 13 | * See the License for the specific language governing permissions and
|
14 | 14 | * limitations under the License.
|
15 | 15 | */
|
| 16 | + |
| 17 | +#include <optional> |
| 18 | +#include <set> |
| 19 | +#include <unordered_map> |
| 20 | + |
16 | 21 | #include "velox/exec/fuzzer/DuckQueryRunner.h"
|
17 | 22 | #include "velox/exec/fuzzer/ToSQLUtil.h"
|
18 | 23 | #include "velox/exec/tests/utils/QueryAssertions.h"
|
@@ -102,23 +107,39 @@ DuckQueryRunner::aggregationFunctionDataSpecs() const {
|
102 | 107 | return kAggregationFunctionDataSpecs;
|
103 | 108 | }
|
104 | 109 |
|
105 |
| -std::multiset<std::vector<velox::variant>> DuckQueryRunner::execute( |
106 |
| - const std::string& sql, |
107 |
| - const std::vector<RowVectorPtr>& input, |
108 |
| - const RowTypePtr& resultType) { |
109 |
| - DuckDbQueryRunner queryRunner; |
110 |
| - queryRunner.createTable("tmp", input); |
111 |
| - return queryRunner.execute(sql, resultType); |
| 110 | +std::pair< |
| 111 | + std::optional<std::multiset<std::vector<velox::variant>>>, |
| 112 | + ReferenceQueryErrorCode> |
| 113 | +DuckQueryRunner::execute(const core::PlanNodePtr& plan) { |
| 114 | + if (std::optional<std::string> sql = toSql(plan)) { |
| 115 | + try { |
| 116 | + DuckDbQueryRunner queryRunner; |
| 117 | + std::unordered_map<std::string, std::vector<RowVectorPtr>> inputMap = |
| 118 | + getAllTables(plan); |
| 119 | + for (const auto& [tableName, input] : inputMap) { |
| 120 | + queryRunner.createTable(tableName, input); |
| 121 | + } |
| 122 | + return std::make_pair( |
| 123 | + queryRunner.execute(*sql, plan->outputType()), |
| 124 | + ReferenceQueryErrorCode::kSuccess); |
| 125 | + } catch (...) { |
| 126 | + LOG(WARNING) << "Query failed in DuckDB"; |
| 127 | + return std::make_pair( |
| 128 | + std::nullopt, ReferenceQueryErrorCode::kReferenceQueryFail); |
| 129 | + } |
| 130 | + } |
| 131 | + |
| 132 | + LOG(INFO) << "Query not supported in DuckDB"; |
| 133 | + return std::make_pair( |
| 134 | + std::nullopt, ReferenceQueryErrorCode::kReferenceQueryUnsupported); |
112 | 135 | }
|
113 | 136 |
|
114 | 137 | std::multiset<std::vector<velox::variant>> DuckQueryRunner::execute(
|
115 | 138 | const std::string& sql,
|
116 |
| - const std::vector<RowVectorPtr>& probeInput, |
117 |
| - const std::vector<RowVectorPtr>& buildInput, |
| 139 | + const std::vector<RowVectorPtr>& input, |
118 | 140 | const RowTypePtr& resultType) {
|
119 | 141 | DuckDbQueryRunner queryRunner;
|
120 |
| - queryRunner.createTable("t", probeInput); |
121 |
| - queryRunner.createTable("u", buildInput); |
| 142 | + queryRunner.createTable("tmp", input); |
122 | 143 | return queryRunner.execute(sql, resultType);
|
123 | 144 | }
|
124 | 145 |
|
@@ -164,6 +185,11 @@ std::optional<std::string> DuckQueryRunner::toSql(
|
164 | 185 | return toSql(joinNode);
|
165 | 186 | }
|
166 | 187 |
|
| 188 | + if (const auto valuesNode = |
| 189 | + std::dynamic_pointer_cast<const core::ValuesNode>(plan)) { |
| 190 | + return toSql(valuesNode); |
| 191 | + } |
| 192 | + |
167 | 193 | VELOX_NYI();
|
168 | 194 | }
|
169 | 195 |
|
@@ -340,137 +366,4 @@ std::optional<std::string> DuckQueryRunner::toSql(
|
340 | 366 |
|
341 | 367 | return sql.str();
|
342 | 368 | }
|
343 |
| - |
344 |
| -std::optional<std::string> DuckQueryRunner::toSql( |
345 |
| - const std::shared_ptr<const core::HashJoinNode>& joinNode) { |
346 |
| - const auto& joinKeysToSql = [](auto keys) { |
347 |
| - std::stringstream out; |
348 |
| - for (auto i = 0; i < keys.size(); ++i) { |
349 |
| - if (i > 0) { |
350 |
| - out << ", "; |
351 |
| - } |
352 |
| - out << keys[i]->name(); |
353 |
| - } |
354 |
| - return out.str(); |
355 |
| - }; |
356 |
| - |
357 |
| - const auto filterToSql = [](core::TypedExprPtr filter) { |
358 |
| - auto call = std::dynamic_pointer_cast<const core::CallTypedExpr>(filter); |
359 |
| - return toCallSql(call); |
360 |
| - }; |
361 |
| - |
362 |
| - const auto& joinConditionAsSql = [&](auto joinNode) { |
363 |
| - std::stringstream out; |
364 |
| - for (auto i = 0; i < joinNode->leftKeys().size(); ++i) { |
365 |
| - if (i > 0) { |
366 |
| - out << " AND "; |
367 |
| - } |
368 |
| - out << joinNode->leftKeys()[i]->name() << " = " |
369 |
| - << joinNode->rightKeys()[i]->name(); |
370 |
| - } |
371 |
| - if (joinNode->filter()) { |
372 |
| - out << " AND " << filterToSql(joinNode->filter()); |
373 |
| - } |
374 |
| - return out.str(); |
375 |
| - }; |
376 |
| - |
377 |
| - const auto& outputNames = joinNode->outputType()->names(); |
378 |
| - |
379 |
| - std::stringstream sql; |
380 |
| - if (joinNode->isLeftSemiProjectJoin()) { |
381 |
| - sql << "SELECT " |
382 |
| - << folly::join(", ", outputNames.begin(), --outputNames.end()); |
383 |
| - } else { |
384 |
| - sql << "SELECT " << folly::join(", ", outputNames); |
385 |
| - } |
386 |
| - |
387 |
| - switch (joinNode->joinType()) { |
388 |
| - case core::JoinType::kInner: |
389 |
| - sql << " FROM t INNER JOIN u ON " << joinConditionAsSql(joinNode); |
390 |
| - break; |
391 |
| - case core::JoinType::kLeft: |
392 |
| - sql << " FROM t LEFT JOIN u ON " << joinConditionAsSql(joinNode); |
393 |
| - break; |
394 |
| - case core::JoinType::kFull: |
395 |
| - sql << " FROM t FULL OUTER JOIN u ON " << joinConditionAsSql(joinNode); |
396 |
| - break; |
397 |
| - case core::JoinType::kLeftSemiFilter: |
398 |
| - // Multiple columns returned by a scalar subquery is not supported in |
399 |
| - // DuckDB. A scalar subquery expression is a subquery that returns one |
400 |
| - // result row from exactly one column for every input row. |
401 |
| - if (joinNode->leftKeys().size() > 1) { |
402 |
| - return std::nullopt; |
403 |
| - } |
404 |
| - sql << " FROM t WHERE " << joinKeysToSql(joinNode->leftKeys()) |
405 |
| - << " IN (SELECT " << joinKeysToSql(joinNode->rightKeys()) |
406 |
| - << " FROM u"; |
407 |
| - if (joinNode->filter()) { |
408 |
| - sql << " WHERE " << filterToSql(joinNode->filter()); |
409 |
| - } |
410 |
| - sql << ")"; |
411 |
| - break; |
412 |
| - case core::JoinType::kLeftSemiProject: |
413 |
| - if (joinNode->isNullAware()) { |
414 |
| - sql << ", " << joinKeysToSql(joinNode->leftKeys()) << " IN (SELECT " |
415 |
| - << joinKeysToSql(joinNode->rightKeys()) << " FROM u"; |
416 |
| - if (joinNode->filter()) { |
417 |
| - sql << " WHERE " << filterToSql(joinNode->filter()); |
418 |
| - } |
419 |
| - sql << ") FROM t"; |
420 |
| - } else { |
421 |
| - sql << ", EXISTS (SELECT * FROM u WHERE " |
422 |
| - << joinConditionAsSql(joinNode); |
423 |
| - sql << ") FROM t"; |
424 |
| - } |
425 |
| - break; |
426 |
| - case core::JoinType::kAnti: |
427 |
| - if (joinNode->isNullAware()) { |
428 |
| - sql << " FROM t WHERE " << joinKeysToSql(joinNode->leftKeys()) |
429 |
| - << " NOT IN (SELECT " << joinKeysToSql(joinNode->rightKeys()) |
430 |
| - << " FROM u"; |
431 |
| - if (joinNode->filter()) { |
432 |
| - sql << " WHERE " << filterToSql(joinNode->filter()); |
433 |
| - } |
434 |
| - sql << ")"; |
435 |
| - } else { |
436 |
| - sql << " FROM t WHERE NOT EXISTS (SELECT * FROM u WHERE " |
437 |
| - << joinConditionAsSql(joinNode); |
438 |
| - sql << ")"; |
439 |
| - } |
440 |
| - break; |
441 |
| - default: |
442 |
| - VELOX_UNREACHABLE( |
443 |
| - "Unknown join type: {}", static_cast<int>(joinNode->joinType())); |
444 |
| - } |
445 |
| - |
446 |
| - return sql.str(); |
447 |
| -} |
448 |
| - |
449 |
| -std::optional<std::string> DuckQueryRunner::toSql( |
450 |
| - const std::shared_ptr<const core::NestedLoopJoinNode>& joinNode) { |
451 |
| - std::stringstream sql; |
452 |
| - sql << "SELECT " << folly::join(", ", joinNode->outputType()->names()); |
453 |
| - |
454 |
| - // Nested loop join without filter. |
455 |
| - VELOX_CHECK( |
456 |
| - joinNode->joinCondition() == nullptr, |
457 |
| - "This code path should be called only for nested loop join without filter"); |
458 |
| - const std::string joinCondition{"(1 = 1)"}; |
459 |
| - switch (joinNode->joinType()) { |
460 |
| - case core::JoinType::kInner: |
461 |
| - sql << " FROM t INNER JOIN u ON " << joinCondition; |
462 |
| - break; |
463 |
| - case core::JoinType::kLeft: |
464 |
| - sql << " FROM t LEFT JOIN u ON " << joinCondition; |
465 |
| - break; |
466 |
| - case core::JoinType::kFull: |
467 |
| - sql << " FROM t FULL OUTER JOIN u ON " << joinCondition; |
468 |
| - break; |
469 |
| - default: |
470 |
| - VELOX_UNREACHABLE( |
471 |
| - "Unknown join type: {}", static_cast<int>(joinNode->joinType())); |
472 |
| - } |
473 |
| - |
474 |
| - return sql.str(); |
475 |
| -} |
476 | 369 | } // namespace facebook::velox::exec::test
|
0 commit comments