diff --git a/c/driver/snowflake/snowflake_test.cc b/c/driver/snowflake/snowflake_test.cc index cdd92e2c71..112d275c53 100644 --- a/c/driver/snowflake/snowflake_test.cc +++ b/c/driver/snowflake/snowflake_test.cc @@ -47,6 +47,10 @@ class SnowflakeQuirks : public adbc_validation::DriverQuirks { AdbcStatusCode SetupDatabase(struct AdbcDatabase* database, struct AdbcError* error) const override { EXPECT_THAT(AdbcDatabaseSetOption(database, "uri", uri_, error), IsOkStatus(error)); + EXPECT_THAT(AdbcDatabaseSetOption( + database, "adbc.snowflake.sql.client_option.use_high_precision", + "false", error), + IsOkStatus(error)); return ADBC_STATUS_OK; } @@ -119,6 +123,7 @@ class SnowflakeQuirks : public adbc_validation::DriverQuirks { bool supports_metadata_current_db_schema() const override { return false; } bool supports_partitioned_data() const override { return false; } bool supports_dynamic_parameter_binding() const override { return false; } + bool supports_error_on_incompatible_schema() const override { return false; } bool ddl_implicit_commit_txn() const override { return true; } std::string db_schema() const override { return "ADBC_TESTING"; } @@ -204,7 +209,7 @@ class SnowflakeStatementTest : public ::testing::Test, expected = {std::nullopt, -42, 0, 42}; break; case NANOARROW_TIME_UNIT_MILLI: - expected = {std::nullopt, -42000, 0, 42000}; + expected = {std::nullopt, -42, 0, 42}; break; case NANOARROW_TIME_UNIT_MICRO: expected = {std::nullopt, -42, 0, 42}; diff --git a/c/validation/adbc_validation.h b/c/validation/adbc_validation.h index fcb4a5c286..6c59d95e09 100644 --- a/c/validation/adbc_validation.h +++ b/c/validation/adbc_validation.h @@ -220,6 +220,10 @@ class DriverQuirks { /// \brief Whether we can get statistics virtual bool supports_statistics() const { return false; } + /// \brief Whether ingest errors on an incompatible schema or simply performs + /// column matching. + virtual bool supports_error_on_incompatible_schema() const { return true; } + /// \brief Default catalog to use for tests virtual std::string catalog() const { return ""; } diff --git a/c/validation/adbc_validation_statement.cc b/c/validation/adbc_validation_statement.cc index f55ca8c8c5..9a7fde5c58 100644 --- a/c/validation/adbc_validation_statement.cc +++ b/c/validation/adbc_validation_statement.cc @@ -487,7 +487,7 @@ void StatementTest::TestSqlIngestInterval() { void StatementTest::TestSqlIngestStringDictionary() { ASSERT_NO_FATAL_FAILURE(TestSqlIngestType( - NANOARROW_TYPE_STRING, {std::nullopt, "", "", "1234", "例"}, + NANOARROW_TYPE_STRING, {"", "", "1234", "例"}, /*dictionary_encode*/ true)); } @@ -865,6 +865,8 @@ void StatementTest::TestSqlIngestErrors() { ::testing::Not(IsOkStatus(&error))); if (error.release) error.release(&error); + if (!quirks()->supports_error_on_incompatible_schema()) { return; } + // ...then try to append an incompatible schema ASSERT_THAT(MakeSchema(&schema.value, {{"int64s", NANOARROW_TYPE_INT64}, {"coltwo", NANOARROW_TYPE_INT64}}), @@ -2212,7 +2214,7 @@ void StatementTest::TestSqlQueryInsertRollback() { ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); ASSERT_THAT( - AdbcStatementSetSqlQuery(&statement, "CREATE TABLE rollbacktest (a INT)", &error), + AdbcStatementSetSqlQuery(&statement, "CREATE TABLE \"rollbacktest\" (a INT)", &error), IsOkStatus(&error)); ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error), IsOkStatus(&error)); @@ -2220,7 +2222,7 @@ void StatementTest::TestSqlQueryInsertRollback() { ASSERT_THAT(AdbcConnectionCommit(&connection, &error), IsOkStatus(&error)); ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, - "INSERT INTO rollbacktest (a) VALUES (1)", &error), + "INSERT INTO \"rollbacktest\" (a) VALUES (1)", &error), IsOkStatus(&error)); ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error), IsOkStatus(&error)); @@ -2228,7 +2230,7 @@ void StatementTest::TestSqlQueryInsertRollback() { ASSERT_THAT(AdbcConnectionRollback(&connection, &error), IsOkStatus(&error)); adbc_validation::StreamReader reader; - ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT * FROM rollbacktest", &error), + ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT * FROM \"rollbacktest\"", &error), IsOkStatus(&error)); ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value, &reader.rows_affected, &error), @@ -2309,20 +2311,20 @@ void StatementTest::TestSqlQueryRowsAffectedDelete() { ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); ASSERT_THAT( - AdbcStatementSetSqlQuery(&statement, "CREATE TABLE delete_test (foo INT)", &error), + AdbcStatementSetSqlQuery(&statement, "CREATE TABLE \"delete_test\" (foo INT)", &error), IsOkStatus(&error)); ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error), IsOkStatus(&error)); ASSERT_THAT(AdbcStatementSetSqlQuery( &statement, - "INSERT INTO delete_test (foo) VALUES (1), (2), (3), (4), (5)", &error), + "INSERT INTO \"delete_test\" (foo) VALUES (1), (2), (3), (4), (5)", &error), IsOkStatus(&error)); ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error), IsOkStatus(&error)); ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, - "DELETE FROM delete_test WHERE foo >= 3", &error), + "DELETE FROM \"delete_test\" WHERE foo >= 3", &error), IsOkStatus(&error)); int64_t rows_affected = 0; @@ -2337,20 +2339,20 @@ void StatementTest::TestSqlQueryRowsAffectedDeleteStream() { ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error)); ASSERT_THAT( - AdbcStatementSetSqlQuery(&statement, "CREATE TABLE delete_test (foo INT)", &error), + AdbcStatementSetSqlQuery(&statement, "CREATE TABLE \"delete_test\" (foo INT)", &error), IsOkStatus(&error)); ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error), IsOkStatus(&error)); ASSERT_THAT(AdbcStatementSetSqlQuery( &statement, - "INSERT INTO delete_test (foo) VALUES (1), (2), (3), (4), (5)", &error), + "INSERT INTO \"delete_test\" (foo) VALUES (1), (2), (3), (4), (5)", &error), IsOkStatus(&error)); ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error), IsOkStatus(&error)); ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, - "DELETE FROM delete_test WHERE foo >= 3", &error), + "DELETE FROM \"delete_test\" WHERE foo >= 3", &error), IsOkStatus(&error)); adbc_validation::StreamReader reader; diff --git a/go/adbc/driver/snowflake/bulk_ingestion.go b/go/adbc/driver/snowflake/bulk_ingestion.go index 3a0d98bbe6..ca5d33c4ef 100644 --- a/go/adbc/driver/snowflake/bulk_ingestion.go +++ b/go/adbc/driver/snowflake/bulk_ingestion.go @@ -130,6 +130,15 @@ func (st *statement) ingestRecord(ctx context.Context) (nrows int64, err error) st.bound = nil }() + var initialRows int64 + + // Check final row count of target table to get definitive rows affected + initialRows, err = countRowsInTable(ctx, st.cnxn.sqldb, strconv.Quote(st.targetTable)) + if err != nil { + st.bound.Release() + return + } + parquetProps, arrowProps := newWriterProps(st.alloc, st.ingestOptions) g := errgroup.Group{} @@ -180,6 +189,7 @@ func (st *statement) ingestRecord(ctx context.Context) (nrows int64, err error) // Check final row count of target table to get definitive rows affected nrows, err = countRowsInTable(ctx, st.cnxn.sqldb, strconv.Quote(st.targetTable)) + nrows = nrows - initialRows return } @@ -193,11 +203,19 @@ func (st *statement) ingestStream(ctx context.Context) (nrows int64, err error) st.streamBind.Release() st.streamBind = nil }() + + var initialRows int64 + // Check final row count of target table to get definitive rows affected + initialRows, err = countRowsInTable(ctx, st.cnxn.sqldb, strconv.Quote(st.targetTable)) + if err != nil { + return + } + defer func() { // Always check the resulting row count, even in the case of an error. We may have ingested part of the data. ctx := context.Background() // TODO(joellubi): switch to context.WithoutCancel(ctx) once we're on Go 1.21 n, countErr := countRowsInTable(ctx, st.cnxn.sqldb, st.targetTable) - nrows = n + nrows = n - initialRows // Ingestion, row-count check, or both could have failed // Wrap any failures as ADBC errors diff --git a/go/adbc/driver/snowflake/driver_test.go b/go/adbc/driver/snowflake/driver_test.go index f7b5ac4c13..ce59419fb1 100644 --- a/go/adbc/driver/snowflake/driver_test.go +++ b/go/adbc/driver/snowflake/driver_test.go @@ -143,7 +143,7 @@ func getArr(arr arrow.Array) interface{} { func (s *SnowflakeQuirks) CreateSampleTable(tableName string, r arrow.Record) error { var b strings.Builder b.WriteString("CREATE OR REPLACE TABLE ") - b.WriteString(tableName) + b.WriteString(strconv.Quote(tableName)) b.WriteString(" (") for i := 0; i < int(r.NumCols()); i++ { @@ -164,7 +164,7 @@ func (s *SnowflakeQuirks) CreateSampleTable(tableName string, r arrow.Record) er return err } - insertQuery := "INSERT INTO " + tableName + " VALUES (" + insertQuery := "INSERT INTO " + strconv.Quote(tableName) + " VALUES (" bindings := strings.Repeat("?,", int(r.NumCols())) insertQuery += bindings[:len(bindings)-1] + ")" @@ -184,7 +184,7 @@ func (s *SnowflakeQuirks) DropTable(cnxn adbc.Connection, tblname string) error } defer stmt.Close() - if err = stmt.SetSqlQuery(`DROP TABLE IF EXISTS ` + tblname); err != nil { + if err = stmt.SetSqlQuery(`DROP TABLE IF EXISTS ` + strconv.Quote(tblname)); err != nil { return err } @@ -486,7 +486,7 @@ func (suite *SnowflakeTests) TestSqlIngestRecordAndStreamAreEquivalent() { suite.Require().NoError(err) suite.EqualValues(3, n) - suite.Require().NoError(suite.stmt.SetSqlQuery("SELECT * FROM bulk_ingest_bind ORDER BY \"col_int64\" ASC")) + suite.Require().NoError(suite.stmt.SetSqlQuery(`SELECT * FROM "bulk_ingest_bind" ORDER BY "col_int64" ASC`)) rdr, n, err := suite.stmt.ExecuteQuery(suite.ctx) suite.Require().NoError(err) defer rdr.Release() @@ -509,7 +509,7 @@ func (suite *SnowflakeTests) TestSqlIngestRecordAndStreamAreEquivalent() { suite.Require().NoError(err) suite.EqualValues(3, n) - suite.Require().NoError(suite.stmt.SetSqlQuery("SELECT * FROM bulk_ingest_bind_stream ORDER BY \"col_int64\" ASC")) + suite.Require().NoError(suite.stmt.SetSqlQuery(`SELECT * FROM "bulk_ingest_bind_stream" ORDER BY "col_int64" ASC`)) rdr, n, err = suite.stmt.ExecuteQuery(suite.ctx) suite.Require().NoError(err) defer rdr.Release() @@ -596,7 +596,7 @@ func (suite *SnowflakeTests) TestSqlIngestRoundtripTypes() { suite.Require().NoError(err) suite.EqualValues(3, n) - suite.Require().NoError(suite.stmt.SetSqlQuery("SELECT * FROM bulk_ingest_roundtrip ORDER BY \"col_int64\" ASC")) + suite.Require().NoError(suite.stmt.SetSqlQuery(`SELECT * FROM "bulk_ingest_roundtrip" ORDER BY "col_int64" ASC`)) rdr, n, err := suite.stmt.ExecuteQuery(suite.ctx) suite.Require().NoError(err) defer rdr.Release() @@ -672,7 +672,7 @@ func (suite *SnowflakeTests) TestSqlIngestTimestampTypes() { suite.Require().NoError(err) suite.EqualValues(3, n) - suite.Require().NoError(suite.stmt.SetSqlQuery("SELECT * FROM bulk_ingest_timestamps ORDER BY \"col_int64\" ASC")) + suite.Require().NoError(suite.stmt.SetSqlQuery(`SELECT * FROM "bulk_ingest_timestamps" ORDER BY "col_int64" ASC`)) rdr, n, err := suite.stmt.ExecuteQuery(suite.ctx) suite.Require().NoError(err) defer rdr.Release() @@ -784,7 +784,7 @@ func (suite *SnowflakeTests) TestSqlIngestDate64Type() { suite.Require().NoError(err) suite.EqualValues(3, n) - suite.Require().NoError(suite.stmt.SetSqlQuery("SELECT * FROM bulk_ingest_date64 ORDER BY \"col_int64\" ASC")) + suite.Require().NoError(suite.stmt.SetSqlQuery(`SELECT * FROM "bulk_ingest_date64" ORDER BY "col_int64" ASC`)) rdr, n, err := suite.stmt.ExecuteQuery(suite.ctx) suite.Require().NoError(err) defer rdr.Release() @@ -877,7 +877,7 @@ func (suite *SnowflakeTests) TestSqlIngestHighPrecision() { suite.Require().NoError(err) suite.EqualValues(3, n) - suite.Require().NoError(suite.stmt.SetSqlQuery("SELECT * FROM bulk_ingest_high_precision ORDER BY \"col_int64\" ASC")) + suite.Require().NoError(suite.stmt.SetSqlQuery(`SELECT * FROM "bulk_ingest_high_precision" ORDER BY "col_int64" ASC`)) suite.Require().NoError(suite.stmt.SetOption(driver.OptionUseHighPrecision, adbc.OptionValueEnabled)) defer func() { suite.Require().NoError(suite.stmt.SetOption(driver.OptionUseHighPrecision, adbc.OptionValueDisabled)) @@ -988,7 +988,7 @@ func (suite *SnowflakeTests) TestSqlIngestLowPrecision() { suite.Require().NoError(err) suite.EqualValues(3, n) - suite.Require().NoError(suite.stmt.SetSqlQuery("SELECT * FROM bulk_ingest_high_precision ORDER BY \"col_int64\" ASC")) + suite.Require().NoError(suite.stmt.SetSqlQuery(`SELECT * FROM "bulk_ingest_high_precision" ORDER BY "col_int64" ASC`)) // OptionUseHighPrecision already disabled rdr, n, err := suite.stmt.ExecuteQuery(suite.ctx) suite.Require().NoError(err) @@ -1106,7 +1106,7 @@ func (suite *SnowflakeTests) TestSqlIngestStructType() { suite.Require().NoError(err) suite.EqualValues(3, n) - suite.Require().NoError(suite.stmt.SetSqlQuery("SELECT * FROM bulk_ingest_struct ORDER BY \"col_int64\" ASC")) + suite.Require().NoError(suite.stmt.SetSqlQuery(`SELECT * FROM "bulk_ingest_struct" ORDER BY "col_int64" ASC`)) rdr, n, err := suite.stmt.ExecuteQuery(suite.ctx) suite.Require().NoError(err) defer rdr.Release() @@ -1210,7 +1210,7 @@ func (suite *SnowflakeTests) TestSqlIngestMapType() { suite.Require().NoError(err) suite.EqualValues(3, n) - suite.Require().NoError(suite.stmt.SetSqlQuery("SELECT * FROM bulk_ingest_map ORDER BY \"col_int64\" ASC")) + suite.Require().NoError(suite.stmt.SetSqlQuery(`SELECT * FROM "bulk_ingest_map" ORDER BY "col_int64" ASC`)) rdr, n, err := suite.stmt.ExecuteQuery(suite.ctx) suite.Require().NoError(err) defer rdr.Release() @@ -1299,7 +1299,7 @@ func (suite *SnowflakeTests) TestSqlIngestListType() { suite.Require().NoError(err) suite.EqualValues(3, n) - suite.Require().NoError(suite.stmt.SetSqlQuery("SELECT * FROM bulk_ingest_list ORDER BY \"col_int64\" ASC")) + suite.Require().NoError(suite.stmt.SetSqlQuery(`SELECT * FROM "bulk_ingest_list" ORDER BY "col_int64" ASC`)) rdr, n, err := suite.stmt.ExecuteQuery(suite.ctx) suite.Require().NoError(err) defer rdr.Release() diff --git a/go/adbc/validation/validation.go b/go/adbc/validation/validation.go index e68ff2a25b..979b6bcae6 100644 --- a/go/adbc/validation/validation.go +++ b/go/adbc/validation/validation.go @@ -800,7 +800,7 @@ func (s *StatementTests) TestSqlIngestInts() { } // use order by clause to ensure we get the same order as the input batch - s.Require().NoError(stmt.SetSqlQuery(`SELECT * FROM bulk_ingest ORDER BY "int64s" DESC NULLS LAST`)) + s.Require().NoError(stmt.SetSqlQuery(`SELECT * FROM "bulk_ingest" ORDER BY "int64s" DESC NULLS LAST`)) rdr, rows, err := stmt.ExecuteQuery(s.ctx) s.Require().NoError(err) if rows != -1 && rows != 3 { @@ -871,7 +871,7 @@ func (s *StatementTests) TestSqlIngestAppend() { } // use order by clause to ensure we get the same order as the input batch - s.Require().NoError(stmt.SetSqlQuery(`SELECT * FROM bulk_ingest ORDER BY "int64s" DESC NULLS LAST`)) + s.Require().NoError(stmt.SetSqlQuery(`SELECT * FROM "bulk_ingest" ORDER BY "int64s" DESC NULLS LAST`)) rdr, rows, err := stmt.ExecuteQuery(s.ctx) s.Require().NoError(err) if rows != -1 && rows != 3 { @@ -945,7 +945,7 @@ func (s *StatementTests) TestSqlIngestReplace() { s.FailNowf("invalid number of affected rows", "should be -1 or 1, got: %d", affected) } - s.Require().NoError(stmt.SetSqlQuery(`SELECT * FROM bulk_ingest`)) + s.Require().NoError(stmt.SetSqlQuery(`SELECT * FROM "bulk_ingest"`)) rdr, rows, err := stmt.ExecuteQuery(s.ctx) s.Require().NoError(err) if rows != -1 && rows != 1 { @@ -1010,7 +1010,7 @@ func (s *StatementTests) TestSqlIngestCreateAppend() { } // validate - s.Require().NoError(stmt.SetSqlQuery(`SELECT * FROM bulk_ingest`)) + s.Require().NoError(stmt.SetSqlQuery(`SELECT * FROM "bulk_ingest"`)) rdr, rows, err := stmt.ExecuteQuery(s.ctx) s.Require().NoError(err) if rows != -1 && rows != 2 {