
Commit

more test fixes
zeroshade committed Feb 23, 2024
1 parent 7c2da50 commit 3c5c3b6
Showing 6 changed files with 58 additions and 29 deletions.
7 changes: 6 additions & 1 deletion c/driver/snowflake/snowflake_test.cc
@@ -47,6 +47,10 @@ class SnowflakeQuirks : public adbc_validation::DriverQuirks {
AdbcStatusCode SetupDatabase(struct AdbcDatabase* database,
struct AdbcError* error) const override {
EXPECT_THAT(AdbcDatabaseSetOption(database, "uri", uri_, error), IsOkStatus(error));
EXPECT_THAT(AdbcDatabaseSetOption(
database, "adbc.snowflake.sql.client_option.use_high_precision",
"false", error),
IsOkStatus(error));
return ADBC_STATUS_OK;
}

@@ -119,6 +123,7 @@ class SnowflakeQuirks : public adbc_validation::DriverQuirks {
bool supports_metadata_current_db_schema() const override { return false; }
bool supports_partitioned_data() const override { return false; }
bool supports_dynamic_parameter_binding() const override { return false; }
bool supports_error_on_incompatible_schema() const override { return false; }
bool ddl_implicit_commit_txn() const override { return true; }
std::string db_schema() const override { return "ADBC_TESTING"; }

@@ -204,7 +209,7 @@ class SnowflakeStatementTest : public ::testing::Test,
expected = {std::nullopt, -42, 0, 42};
break;
case NANOARROW_TIME_UNIT_MILLI:
expected = {std::nullopt, -42000, 0, 42000};
expected = {std::nullopt, -42, 0, 42};
break;
case NANOARROW_TIME_UNIT_MICRO:
expected = {std::nullopt, -42, 0, 42};
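Note on the option set in SetupDatabase above: "adbc.snowflake.sql.client_option.use_high_precision" controls whether the Go Snowflake driver returns fixed-point NUMBER columns as Decimal128 (when enabled) or as plain int64/float64 (when disabled), presumably matching what the shared validation tests expect. The sketch below shows the equivalent configuration from Go; it assumes the driver's NewDriver/NewDatabase entry points, uses a placeholder URI, and the arrow module version in the import path should be adjusted to your go.mod.

// Sketch: configure a Snowflake ADBC database with high precision disabled,
// mirroring the C++ test quirk above. The URI is a placeholder, not a real DSN.
package main

import (
	"fmt"

	"github.com/apache/arrow-adbc/go/adbc"
	"github.com/apache/arrow-adbc/go/adbc/driver/snowflake"
	"github.com/apache/arrow/go/v15/memory" // adjust the version to your go.mod
)

func main() {
	drv := snowflake.NewDriver(memory.DefaultAllocator)
	db, err := drv.NewDatabase(map[string]string{
		adbc.OptionKeyURI: "user:password@account/database/schema", // placeholder
		// Return NUMBER columns as int64/float64 instead of Decimal128.
		"adbc.snowflake.sql.client_option.use_high_precision": "false",
	})
	if err != nil {
		panic(err)
	}
	_ = db // open connections from db as usual
	fmt.Println("snowflake database configured with high precision disabled")
}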
4 changes: 4 additions & 0 deletions c/validation/adbc_validation.h
@@ -220,6 +220,10 @@ class DriverQuirks {
/// \brief Whether we can get statistics
virtual bool supports_statistics() const { return false; }

/// \brief Whether ingest errors on an incompatible schema or simply performs
/// column matching.
virtual bool supports_error_on_incompatible_schema() const { return true; }

/// \brief Default catalog to use for tests
virtual std::string catalog() const { return ""; }

22 changes: 12 additions & 10 deletions c/validation/adbc_validation_statement.cc
@@ -487,7 +487,7 @@ void StatementTest::TestSqlIngestInterval() {

void StatementTest::TestSqlIngestStringDictionary() {
ASSERT_NO_FATAL_FAILURE(TestSqlIngestType<std::string>(
NANOARROW_TYPE_STRING, {std::nullopt, "", "", "1234", ""},
NANOARROW_TYPE_STRING, {"", "", "1234", ""},
/*dictionary_encode*/ true));
}

@@ -865,6 +865,8 @@ void StatementTest::TestSqlIngestErrors() {
::testing::Not(IsOkStatus(&error)));
if (error.release) error.release(&error);

if (!quirks()->supports_error_on_incompatible_schema()) { return; }

// ...then try to append an incompatible schema
ASSERT_THAT(MakeSchema(&schema.value, {{"int64s", NANOARROW_TYPE_INT64},
{"coltwo", NANOARROW_TYPE_INT64}}),
@@ -2212,23 +2214,23 @@ void StatementTest::TestSqlQueryInsertRollback() {
ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));

ASSERT_THAT(
AdbcStatementSetSqlQuery(&statement, "CREATE TABLE rollbacktest (a INT)", &error),
AdbcStatementSetSqlQuery(&statement, "CREATE TABLE \"rollbacktest\" (a INT)", &error),
IsOkStatus(&error));
ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error),
IsOkStatus(&error));

ASSERT_THAT(AdbcConnectionCommit(&connection, &error), IsOkStatus(&error));

ASSERT_THAT(AdbcStatementSetSqlQuery(&statement,
"INSERT INTO rollbacktest (a) VALUES (1)", &error),
"INSERT INTO \"rollbacktest\" (a) VALUES (1)", &error),
IsOkStatus(&error));
ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error),
IsOkStatus(&error));

ASSERT_THAT(AdbcConnectionRollback(&connection, &error), IsOkStatus(&error));

adbc_validation::StreamReader reader;
ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT * FROM rollbacktest", &error),
ASSERT_THAT(AdbcStatementSetSqlQuery(&statement, "SELECT * FROM \"rollbacktest\"", &error),
IsOkStatus(&error));
ASSERT_THAT(AdbcStatementExecuteQuery(&statement, &reader.stream.value,
&reader.rows_affected, &error),
@@ -2309,20 +2311,20 @@ void StatementTest::TestSqlQueryRowsAffectedDelete() {
ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));

ASSERT_THAT(
AdbcStatementSetSqlQuery(&statement, "CREATE TABLE delete_test (foo INT)", &error),
AdbcStatementSetSqlQuery(&statement, "CREATE TABLE \"delete_test\" (foo INT)", &error),
IsOkStatus(&error));
ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error),
IsOkStatus(&error));

ASSERT_THAT(AdbcStatementSetSqlQuery(
&statement,
"INSERT INTO delete_test (foo) VALUES (1), (2), (3), (4), (5)", &error),
"INSERT INTO \"delete_test\" (foo) VALUES (1), (2), (3), (4), (5)", &error),
IsOkStatus(&error));
ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error),
IsOkStatus(&error));

ASSERT_THAT(AdbcStatementSetSqlQuery(&statement,
"DELETE FROM delete_test WHERE foo >= 3", &error),
"DELETE FROM \"delete_test\" WHERE foo >= 3", &error),
IsOkStatus(&error));

int64_t rows_affected = 0;
@@ -2337,20 +2339,20 @@ void StatementTest::TestSqlQueryRowsAffectedDeleteStream() {
ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error), IsOkStatus(&error));

ASSERT_THAT(
AdbcStatementSetSqlQuery(&statement, "CREATE TABLE delete_test (foo INT)", &error),
AdbcStatementSetSqlQuery(&statement, "CREATE TABLE \"delete_test\" (foo INT)", &error),
IsOkStatus(&error));
ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error),
IsOkStatus(&error));

ASSERT_THAT(AdbcStatementSetSqlQuery(
&statement,
"INSERT INTO delete_test (foo) VALUES (1), (2), (3), (4), (5)", &error),
"INSERT INTO \"delete_test\" (foo) VALUES (1), (2), (3), (4), (5)", &error),
IsOkStatus(&error));
ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, nullptr, &error),
IsOkStatus(&error));

ASSERT_THAT(AdbcStatementSetSqlQuery(&statement,
"DELETE FROM delete_test WHERE foo >= 3", &error),
"DELETE FROM \"delete_test\" WHERE foo >= 3", &error),
IsOkStatus(&error));

adbc_validation::StreamReader reader;
20 changes: 19 additions & 1 deletion go/adbc/driver/snowflake/bulk_ingestion.go
@@ -130,6 +130,15 @@ func (st *statement) ingestRecord(ctx context.Context) (nrows int64, err error)
st.bound = nil
}()

var initialRows int64

// Check initial row count of target table so rows affected can be reported as a delta
initialRows, err = countRowsInTable(ctx, st.cnxn.sqldb, strconv.Quote(st.targetTable))
if err != nil {
st.bound.Release()
return
}

parquetProps, arrowProps := newWriterProps(st.alloc, st.ingestOptions)
g := errgroup.Group{}

@@ -180,6 +189,7 @@ func (st *statement) ingestRecord(ctx context.Context) (nrows int64, err error)

// Check final row count of target table to get definitive rows affected
nrows, err = countRowsInTable(ctx, st.cnxn.sqldb, strconv.Quote(st.targetTable))
nrows = nrows - initialRows
return
}

@@ -193,11 +203,19 @@ func (st *statement) ingestStream(ctx context.Context) (nrows int64, err error)
st.streamBind.Release()
st.streamBind = nil
}()

var initialRows int64
// Check initial row count of target table so rows affected can be reported as a delta
initialRows, err = countRowsInTable(ctx, st.cnxn.sqldb, strconv.Quote(st.targetTable))
if err != nil {
return
}

defer func() {
// Always check the resulting row count, even in the case of an error. We may have ingested part of the data.
ctx := context.Background() // TODO(joellubi): switch to context.WithoutCancel(ctx) once we're on Go 1.21
n, countErr := countRowsInTable(ctx, st.cnxn.sqldb, st.targetTable)
nrows = n
nrows = n - initialRows

// Ingestion, row-count check, or both could have failed
// Wrap any failures as ADBC errors
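The bulk_ingestion.go changes above make the reported row count a delta: the target table's row count is captured before ingestion (initialRows) and subtracted from the final count, so rows that already existed (for example when appending) are not reported as affected. Below is a minimal sketch of that pattern; countRowsInTable is assumed here to be a simple COUNT(*) helper over database/sql, and the driver's real helper may differ.

// Sketch of the before/after row-count delta used by ingestRecord and
// ingestStream. countRowsInTable is an assumed COUNT(*) helper; the driver's
// actual implementation may differ.
package ingestsketch

import (
	"context"
	"database/sql"
)

func countRowsInTable(ctx context.Context, db *sql.DB, quotedTable string) (int64, error) {
	var n int64
	err := db.QueryRowContext(ctx, "SELECT COUNT(*) FROM "+quotedTable).Scan(&n)
	return n, err
}

// ingestWithRowDelta reports rows affected as finalRows - initialRows, so only
// rows added by doIngest are counted even if the target table was not empty.
func ingestWithRowDelta(ctx context.Context, db *sql.DB, quotedTable string, doIngest func() error) (int64, error) {
	initialRows, err := countRowsInTable(ctx, db, quotedTable)
	if err != nil {
		return 0, err
	}

	ingestErr := doIngest()

	// Re-count even if ingestion failed: part of the data may have landed.
	finalRows, countErr := countRowsInTable(context.Background(), db, quotedTable)
	if ingestErr != nil {
		return finalRows - initialRows, ingestErr
	}
	return finalRows - initialRows, countErr
}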
26 changes: 13 additions & 13 deletions go/adbc/driver/snowflake/driver_test.go
@@ -143,7 +143,7 @@ func getArr(arr arrow.Array) interface{} {
func (s *SnowflakeQuirks) CreateSampleTable(tableName string, r arrow.Record) error {
var b strings.Builder
b.WriteString("CREATE OR REPLACE TABLE ")
b.WriteString(tableName)
b.WriteString(strconv.Quote(tableName))
b.WriteString(" (")

for i := 0; i < int(r.NumCols()); i++ {
@@ -164,7 +164,7 @@ func (s *SnowflakeQuirks) CreateSampleTable(tableName string, r arrow.Record) er
return err
}

insertQuery := "INSERT INTO " + tableName + " VALUES ("
insertQuery := "INSERT INTO " + strconv.Quote(tableName) + " VALUES ("
bindings := strings.Repeat("?,", int(r.NumCols()))
insertQuery += bindings[:len(bindings)-1] + ")"

@@ -184,7 +184,7 @@ func (s *SnowflakeQuirks) DropTable(cnxn adbc.Connection, tblname string) error
}
defer stmt.Close()

if err = stmt.SetSqlQuery(`DROP TABLE IF EXISTS ` + tblname); err != nil {
if err = stmt.SetSqlQuery(`DROP TABLE IF EXISTS ` + strconv.Quote(tblname)); err != nil {
return err
}

@@ -486,7 +486,7 @@ func (suite *SnowflakeTests) TestSqlIngestRecordAndStreamAreEquivalent() {
suite.Require().NoError(err)
suite.EqualValues(3, n)

suite.Require().NoError(suite.stmt.SetSqlQuery("SELECT * FROM bulk_ingest_bind ORDER BY \"col_int64\" ASC"))
suite.Require().NoError(suite.stmt.SetSqlQuery(`SELECT * FROM "bulk_ingest_bind" ORDER BY "col_int64" ASC`))
rdr, n, err := suite.stmt.ExecuteQuery(suite.ctx)
suite.Require().NoError(err)
defer rdr.Release()
@@ -509,7 +509,7 @@ func (suite *SnowflakeTests) TestSqlIngestRecordAndStreamAreEquivalent() {
suite.Require().NoError(err)
suite.EqualValues(3, n)

suite.Require().NoError(suite.stmt.SetSqlQuery("SELECT * FROM bulk_ingest_bind_stream ORDER BY \"col_int64\" ASC"))
suite.Require().NoError(suite.stmt.SetSqlQuery(`SELECT * FROM "bulk_ingest_bind_stream" ORDER BY "col_int64" ASC`))
rdr, n, err = suite.stmt.ExecuteQuery(suite.ctx)
suite.Require().NoError(err)
defer rdr.Release()
@@ -596,7 +596,7 @@ func (suite *SnowflakeTests) TestSqlIngestRoundtripTypes() {
suite.Require().NoError(err)
suite.EqualValues(3, n)

suite.Require().NoError(suite.stmt.SetSqlQuery("SELECT * FROM bulk_ingest_roundtrip ORDER BY \"col_int64\" ASC"))
suite.Require().NoError(suite.stmt.SetSqlQuery(`SELECT * FROM "bulk_ingest_roundtrip" ORDER BY "col_int64" ASC`))
rdr, n, err := suite.stmt.ExecuteQuery(suite.ctx)
suite.Require().NoError(err)
defer rdr.Release()
@@ -672,7 +672,7 @@ func (suite *SnowflakeTests) TestSqlIngestTimestampTypes() {
suite.Require().NoError(err)
suite.EqualValues(3, n)

suite.Require().NoError(suite.stmt.SetSqlQuery("SELECT * FROM bulk_ingest_timestamps ORDER BY \"col_int64\" ASC"))
suite.Require().NoError(suite.stmt.SetSqlQuery(`SELECT * FROM "bulk_ingest_timestamps" ORDER BY "col_int64" ASC`))
rdr, n, err := suite.stmt.ExecuteQuery(suite.ctx)
suite.Require().NoError(err)
defer rdr.Release()
@@ -784,7 +784,7 @@ func (suite *SnowflakeTests) TestSqlIngestDate64Type() {
suite.Require().NoError(err)
suite.EqualValues(3, n)

suite.Require().NoError(suite.stmt.SetSqlQuery("SELECT * FROM bulk_ingest_date64 ORDER BY \"col_int64\" ASC"))
suite.Require().NoError(suite.stmt.SetSqlQuery(`SELECT * FROM "bulk_ingest_date64" ORDER BY "col_int64" ASC`))
rdr, n, err := suite.stmt.ExecuteQuery(suite.ctx)
suite.Require().NoError(err)
defer rdr.Release()
@@ -877,7 +877,7 @@ func (suite *SnowflakeTests) TestSqlIngestHighPrecision() {
suite.Require().NoError(err)
suite.EqualValues(3, n)

suite.Require().NoError(suite.stmt.SetSqlQuery("SELECT * FROM bulk_ingest_high_precision ORDER BY \"col_int64\" ASC"))
suite.Require().NoError(suite.stmt.SetSqlQuery(`SELECT * FROM "bulk_ingest_high_precision" ORDER BY "col_int64" ASC`))
suite.Require().NoError(suite.stmt.SetOption(driver.OptionUseHighPrecision, adbc.OptionValueEnabled))
defer func() {
suite.Require().NoError(suite.stmt.SetOption(driver.OptionUseHighPrecision, adbc.OptionValueDisabled))
@@ -988,7 +988,7 @@ func (suite *SnowflakeTests) TestSqlIngestLowPrecision() {
suite.Require().NoError(err)
suite.EqualValues(3, n)

suite.Require().NoError(suite.stmt.SetSqlQuery("SELECT * FROM bulk_ingest_high_precision ORDER BY \"col_int64\" ASC"))
suite.Require().NoError(suite.stmt.SetSqlQuery(`SELECT * FROM "bulk_ingest_high_precision" ORDER BY "col_int64" ASC`))
// OptionUseHighPrecision already disabled
rdr, n, err := suite.stmt.ExecuteQuery(suite.ctx)
suite.Require().NoError(err)
@@ -1106,7 +1106,7 @@ func (suite *SnowflakeTests) TestSqlIngestStructType() {
suite.Require().NoError(err)
suite.EqualValues(3, n)

suite.Require().NoError(suite.stmt.SetSqlQuery("SELECT * FROM bulk_ingest_struct ORDER BY \"col_int64\" ASC"))
suite.Require().NoError(suite.stmt.SetSqlQuery(`SELECT * FROM "bulk_ingest_struct" ORDER BY "col_int64" ASC`))
rdr, n, err := suite.stmt.ExecuteQuery(suite.ctx)
suite.Require().NoError(err)
defer rdr.Release()
@@ -1210,7 +1210,7 @@ func (suite *SnowflakeTests) TestSqlIngestMapType() {
suite.Require().NoError(err)
suite.EqualValues(3, n)

suite.Require().NoError(suite.stmt.SetSqlQuery("SELECT * FROM bulk_ingest_map ORDER BY \"col_int64\" ASC"))
suite.Require().NoError(suite.stmt.SetSqlQuery(`SELECT * FROM "bulk_ingest_map" ORDER BY "col_int64" ASC`))
rdr, n, err := suite.stmt.ExecuteQuery(suite.ctx)
suite.Require().NoError(err)
defer rdr.Release()
@@ -1299,7 +1299,7 @@ func (suite *SnowflakeTests) TestSqlIngestListType() {
suite.Require().NoError(err)
suite.EqualValues(3, n)

suite.Require().NoError(suite.stmt.SetSqlQuery("SELECT * FROM bulk_ingest_list ORDER BY \"col_int64\" ASC"))
suite.Require().NoError(suite.stmt.SetSqlQuery(`SELECT * FROM "bulk_ingest_list" ORDER BY "col_int64" ASC`))
rdr, n, err := suite.stmt.ExecuteQuery(suite.ctx)
suite.Require().NoError(err)
defer rdr.Release()
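The switch to double-quoted identifiers throughout these tests matters because Snowflake folds unquoted identifiers to upper case: a table created as "bulk_ingest_bind" can only be found again if the name is quoted the same way everywhere. The Go tests use strconv.Quote for this, which for simple identifiers produces exactly the SQL double-quoted form; a small self-contained illustration (the table and column names are just the ones used in the tests above):

// Demonstrates the identifier quoting used in the Go tests: strconv.Quote
// wraps a simple name in double quotes, which Snowflake treats as a
// case-sensitive identifier rather than folding it to upper case.
package main

import (
	"fmt"
	"strconv"
)

func main() {
	table, orderCol := "bulk_ingest_bind", "col_int64"
	query := fmt.Sprintf("SELECT * FROM %s ORDER BY %s ASC",
		strconv.Quote(table), strconv.Quote(orderCol))
	fmt.Println(query) // SELECT * FROM "bulk_ingest_bind" ORDER BY "col_int64" ASC
}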
8 changes: 4 additions & 4 deletions go/adbc/validation/validation.go
@@ -800,7 +800,7 @@ func (s *StatementTests) TestSqlIngestInts() {
}

// use order by clause to ensure we get the same order as the input batch
s.Require().NoError(stmt.SetSqlQuery(`SELECT * FROM bulk_ingest ORDER BY "int64s" DESC NULLS LAST`))
s.Require().NoError(stmt.SetSqlQuery(`SELECT * FROM "bulk_ingest" ORDER BY "int64s" DESC NULLS LAST`))
rdr, rows, err := stmt.ExecuteQuery(s.ctx)
s.Require().NoError(err)
if rows != -1 && rows != 3 {
@@ -871,7 +871,7 @@ func (s *StatementTests) TestSqlIngestAppend() {
}

// use order by clause to ensure we get the same order as the input batch
s.Require().NoError(stmt.SetSqlQuery(`SELECT * FROM bulk_ingest ORDER BY "int64s" DESC NULLS LAST`))
s.Require().NoError(stmt.SetSqlQuery(`SELECT * FROM "bulk_ingest" ORDER BY "int64s" DESC NULLS LAST`))
rdr, rows, err := stmt.ExecuteQuery(s.ctx)
s.Require().NoError(err)
if rows != -1 && rows != 3 {
@@ -945,7 +945,7 @@ func (s *StatementTests) TestSqlIngestReplace() {
s.FailNowf("invalid number of affected rows", "should be -1 or 1, got: %d", affected)
}

s.Require().NoError(stmt.SetSqlQuery(`SELECT * FROM bulk_ingest`))
s.Require().NoError(stmt.SetSqlQuery(`SELECT * FROM "bulk_ingest"`))
rdr, rows, err := stmt.ExecuteQuery(s.ctx)
s.Require().NoError(err)
if rows != -1 && rows != 1 {
@@ -1010,7 +1010,7 @@ func (s *StatementTests) TestSqlIngestCreateAppend() {
}

// validate
s.Require().NoError(stmt.SetSqlQuery(`SELECT * FROM bulk_ingest`))
s.Require().NoError(stmt.SetSqlQuery(`SELECT * FROM "bulk_ingest"`))
rdr, rows, err := stmt.ExecuteQuery(s.ctx)
s.Require().NoError(err)
if rows != -1 && rows != 2 {
