Skip to content

Commit

Permalink
fix(c/driver/sqlite): escape table names in INSERT, too (#1003)
Browse files Browse the repository at this point in the history
Fixes #1000.
  • Loading branch information
lidavidm authored Sep 1, 2023
1 parent d772fd1 commit 932b721
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 29 deletions.
57 changes: 28 additions & 29 deletions c/driver/sqlite/sqlite.c
Original file line number Diff line number Diff line change
Expand Up @@ -1081,26 +1081,28 @@ AdbcStatusCode SqliteStatementInitIngest(struct SqliteStatement* stmt,
sqlite3_str* create_query = sqlite3_str_new(NULL);
if (sqlite3_str_errcode(create_query)) {
SetError(error, "[SQLite] %s", sqlite3_errmsg(stmt->conn));
sqlite3_free(sqlite3_str_finish(create_query));
return ADBC_STATUS_INTERNAL;
}
struct StringBuilder insert_query = {0};

if (StringBuilderInit(&insert_query, /*initial_size=*/256) != 0) {
SetError(error, "[SQLite] Could not initiate StringBuilder");
sqlite3_str* insert_query = sqlite3_str_new(NULL);
if (sqlite3_str_errcode(insert_query)) {
SetError(error, "[SQLite] %s", sqlite3_errmsg(stmt->conn));
sqlite3_free(sqlite3_str_finish(create_query));
sqlite3_free(sqlite3_str_finish(insert_query));
return ADBC_STATUS_INTERNAL;
}

sqlite3_str_appendf(create_query, "%s%Q%s", "CREATE TABLE ", stmt->target_table, " (");
sqlite3_str_appendf(create_query, "CREATE TABLE %Q (", stmt->target_table);
if (sqlite3_str_errcode(create_query)) {
SetError(error, "[SQLite] %s", sqlite3_errmsg(stmt->conn));
SetError(error, "[SQLite] Failed to build CREATE: %s", sqlite3_errmsg(stmt->conn));
code = ADBC_STATUS_INTERNAL;
goto cleanup;
}

if (StringBuilderAppend(&insert_query, "%s%s%s", "INSERT INTO ", stmt->target_table,
" VALUES (") != 0) {
SetError(error, "[SQLite] Call to StringBuilderAppend failed");
sqlite3_str_appendf(insert_query, "INSERT INTO %Q VALUES (", stmt->target_table);
if (sqlite3_str_errcode(insert_query)) {
SetError(error, "[SQLite] Failed to build INSERT: %s", sqlite3_errmsg(stmt->conn));
code = ADBC_STATUS_INTERNAL;
goto cleanup;
}
Expand All @@ -1111,23 +1113,24 @@ AdbcStatusCode SqliteStatementInitIngest(struct SqliteStatement* stmt,
if (i > 0) {
sqlite3_str_appendf(create_query, "%s", ", ");
if (sqlite3_str_errcode(create_query)) {
SetError(error, "[SQLite] %s", sqlite3_errmsg(stmt->conn));
SetError(error, "[SQLite] Failed to build CREATE: %s",
sqlite3_errmsg(stmt->conn));
code = ADBC_STATUS_INTERNAL;
goto cleanup;
}
}

sqlite3_str_appendf(create_query, "%Q", stmt->binder.schema.children[i]->name);
if (sqlite3_str_errcode(create_query)) {
SetError(error, "[SQLite] %s", sqlite3_errmsg(stmt->conn));
SetError(error, "[SQLite] Failed to build CREATE: %s", sqlite3_errmsg(stmt->conn));
code = ADBC_STATUS_INTERNAL;
goto cleanup;
}

int status =
ArrowSchemaViewInit(&view, stmt->binder.schema.children[i], &arrow_error);
if (status != 0) {
SetError(error, "Failed to parse schema for column %d: %s (%d): %s", i,
SetError(error, "[SQLite] Failed to parse schema for column %d: %s (%d): %s", i,
strerror(status), status, arrow_error.message);
code = ADBC_STATUS_INTERNAL;
goto cleanup;
Expand Down Expand Up @@ -1160,30 +1163,24 @@ AdbcStatusCode SqliteStatementInitIngest(struct SqliteStatement* stmt,
break;
}

if (i > 0) {
if (StringBuilderAppend(&insert_query, "%s", ", ") != 0) {
SetError(error, "[SQLite] Call to StringBuilderAppend failed");
code = ADBC_STATUS_INTERNAL;
goto cleanup;
}
}

if (StringBuilderAppend(&insert_query, "%s", "?") != 0) {
SetError(error, "[SQLite] Call to StringBuilderAppend failed");
sqlite3_str_appendf(insert_query, "%s?", (i > 0 ? ", " : ""));
if (sqlite3_str_errcode(insert_query)) {
SetError(error, "[SQLite] Failed to build INSERT: %s", sqlite3_errmsg(stmt->conn));
code = ADBC_STATUS_INTERNAL;
goto cleanup;
}
}

sqlite3_str_appendchar(create_query, 1, ')');
if (sqlite3_str_errcode(create_query)) {
SetError(error, "[SQLite] %s", sqlite3_errmsg(stmt->conn));
SetError(error, "[SQLite] Failed to build CREATE: %s", sqlite3_errmsg(stmt->conn));
code = ADBC_STATUS_INTERNAL;
goto cleanup;
}

if (StringBuilderAppend(&insert_query, "%s", ")") != 0) {
SetError(error, "[SQLite] Call to StringBuilderAppend failed");
sqlite3_str_appendchar(insert_query, 1, ')');
if (sqlite3_str_errcode(insert_query)) {
SetError(error, "[SQLite] Failed to build INSERT: %s", sqlite3_errmsg(stmt->conn));
code = ADBC_STATUS_INTERNAL;
goto cleanup;
}
Expand All @@ -1207,11 +1204,13 @@ AdbcStatusCode SqliteStatementInitIngest(struct SqliteStatement* stmt,
}

if (code == ADBC_STATUS_OK) {
int rc = sqlite3_prepare_v2(stmt->conn, insert_query.buffer, (int)insert_query.size,
insert_statement, /*pzTail=*/NULL);
int rc = sqlite3_prepare_v2(stmt->conn, sqlite3_str_value(insert_query),
sqlite3_str_length(insert_query), insert_statement,
/*pzTail=*/NULL);
if (rc != SQLITE_OK) {
SetError(error, "[SQLite] Failed to prepare statement: %s (executed '%s')",
sqlite3_errmsg(stmt->conn), insert_query.buffer);
SetError(error, "[SQLite] Failed to prepare statement: %s (executed '%.*s')",
sqlite3_errmsg(stmt->conn), sqlite3_str_length(insert_query),
sqlite3_str_value(insert_query));
code = ADBC_STATUS_INTERNAL;
}
}
Expand All @@ -1220,7 +1219,7 @@ AdbcStatusCode SqliteStatementInitIngest(struct SqliteStatement* stmt,

cleanup:
sqlite3_free(sqlite3_str_finish(create_query));
StringBuilderReset(&insert_query);
sqlite3_free(sqlite3_str_finish(insert_query));
return code;
}

Expand Down
31 changes: 31 additions & 0 deletions c/driver/sqlite/sqlite_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,37 @@ class SqliteStatementTest : public ::testing::Test,
};
ADBCV_TEST_STATEMENT(SqliteStatementTest)

TEST_F(SqliteStatementTest, SqlIngestNameEscaping) {
ASSERT_THAT(quirks()->DropTable(&connection, "\"test-table\"", &error),
adbc_validation::IsOkStatus(&error));

std::string table = "test-table";
adbc_validation::Handle<struct ArrowSchema> schema;
adbc_validation::Handle<struct ArrowArray> array;
struct ArrowError na_error;
ASSERT_THAT(
adbc_validation::MakeSchema(&schema.value, {{"index", NANOARROW_TYPE_INT64},
{"create", NANOARROW_TYPE_STRING}}),
adbc_validation::IsOkErrno());
ASSERT_THAT((adbc_validation::MakeBatch<int64_t, std::string>(
&schema.value, &array.value, &na_error, {42, -42, std::nullopt},
{"foo", std::nullopt, ""})),
adbc_validation::IsOkErrno(&na_error));

ASSERT_THAT(AdbcStatementNew(&connection, &statement, &error),
adbc_validation::IsOkStatus(&error));
ASSERT_THAT(AdbcStatementSetOption(&statement, ADBC_INGEST_OPTION_TARGET_TABLE,
table.c_str(), &error),
adbc_validation::IsOkStatus(&error));
ASSERT_THAT(AdbcStatementBind(&statement, &array.value, &schema.value, &error),
adbc_validation::IsOkStatus(&error));

int64_t rows_affected = 0;
ASSERT_THAT(AdbcStatementExecuteQuery(&statement, nullptr, &rows_affected, &error),
adbc_validation::IsOkStatus(&error));
ASSERT_EQ(3, rows_affected);
}

// -- SQLite Specific Tests ------------------------------------------

constexpr size_t kInferRows = 16;
Expand Down

0 comments on commit 932b721

Please sign in to comment.