RD-14980: Support for INSERT, UPDATE, DELETE
bgaidioz committed Nov 29, 2024
1 parent fd270d1 commit aff7a79
Showing 3 changed files with 140 additions and 19 deletions.
12 changes: 10 additions & 2 deletions src/main/scala/com/rawlabs/das/databricks/DASDatabricks.scala
@@ -14,7 +14,7 @@ package com.rawlabs.das.databricks

import com.databricks.sdk.WorkspaceClient
import com.databricks.sdk.core.DatabricksConfig
import com.databricks.sdk.service.catalog.ListTablesRequest
import com.databricks.sdk.service.catalog.{GetTableRequest, ListTablesRequest}
import com.databricks.sdk.service.sql.ListWarehousesRequest
import com.rawlabs.das.sdk.{DASFunction, DASSdk, DASTable}
import com.rawlabs.protocol.das.{FunctionDefinition, TableDefinition}
@@ -41,7 +41,15 @@ class DASDatabricks(options: Map[String, String]) extends DASSdk {
val databricksTables = databricksClient.tables().list(req)
val tables = mutable.Map.empty[String, DASDatabricksTable]
databricksTables.forEach { databricksTable =>
tables.put(databricksTable.getName, new DASDatabricksTable(databricksClient, warehouse, databricksTable))
// `databricksTable` is a `TableInfo` whose `getTableConstraints` would tell us whether
// the table has a primary key column (useful for UPDATE calls), but the list call does
// not populate it. We have to issue an individual `GetTableRequest` (the single-table
// call, which returns the same object with the constraints filled in).
val tableDetails = {
val tableReq = new GetTableRequest().setFullName(catalog + '.' + schema + '.' + databricksTable.getName)
databricksClient.tables().get(tableReq)
}
tables.put(databricksTable.getName, new DASDatabricksTable(databricksClient, warehouse, tableDetails))
}
tables.toMap
}
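Note on the change above: the constraint metadata only arrives via the single-table lookup. A minimal sketch of that pattern in isolation, assuming a configured `WorkspaceClient`; names other than the SDK calls are illustrative:

```scala
import com.databricks.sdk.WorkspaceClient
import com.databricks.sdk.service.catalog.{GetTableRequest, TableInfo}

// Unlike tables().list(...), tables().get(...) returns a TableInfo with
// getTableConstraints populated, which is what the UPDATE support needs.
def tableWithConstraints(client: WorkspaceClient, catalog: String, schema: String, table: String): TableInfo = {
  val req = new GetTableRequest().setFullName(s"$catalog.$schema.$table")
  client.tables().get(req)
}
```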
144 changes: 130 additions & 14 deletions src/main/scala/com/rawlabs/das/databricks/DASDatabricksTable.scala
@@ -15,18 +15,21 @@ package com.rawlabs.das.databricks
import com.databricks.sdk.WorkspaceClient
import com.databricks.sdk.service.catalog.{ColumnInfo, ColumnTypeName, TableInfo}
import com.databricks.sdk.service.sql._
import com.rawlabs.das.sdk.{DASExecuteResult, DASTable}
import com.rawlabs.das.sdk.{DASExecuteResult, DASSdkException, DASTable}
import com.rawlabs.protocol.das._
import com.rawlabs.protocol.raw.{Type, Value}
import com.typesafe.scalalogging.StrictLogging

import scala.annotation.tailrec
import scala.collection.JavaConverters.collectionAsScalaIterableConverter
import scala.collection.mutable

class DASDatabricksTable(client: WorkspaceClient, warehouseID: String, databricksTable: TableInfo)
extends DASTable
with StrictLogging {

private val tableFullName = databricksTable.getSchemaName + '.' + databricksTable.getName

override def getRelSize(quals: Seq[Qual], columns: Seq[String]): (Int, Int) = REL_SIZE

override def execute(
@@ -36,8 +39,7 @@ class DASDatabricksTable(client: WorkspaceClient, warehouseID: String, databrick
maybeLimit: Option[Long]
): DASExecuteResult = {
val databricksColumns = if (columns.isEmpty) Seq("NULL") else columns.map(databricksColumnName)
var query =
s"SELECT ${databricksColumns.mkString(",")} FROM " + databricksTable.getSchemaName + '.' + databricksTable.getName
var query = s"SELECT ${databricksColumns.mkString(",")} FROM " + tableFullName
val stmt = new ExecuteStatementRequest()
val parameters = new java.util.LinkedList[StatementParameterListItem]
if (quals.nonEmpty) {
@@ -93,9 +95,11 @@ class DASDatabricksTable(client: WorkspaceClient, warehouseID: String, databrick

stmt.setStatement(query).setWarehouseId(warehouseID).setDisposition(Disposition.INLINE).setFormat(Format.JSON_ARRAY)
val executeAPI = client.statementExecution()
val response1 = executeAPI.executeStatement(stmt)
val response = getResult(response1)
new DASDatabricksExecuteResult(executeAPI, response)
val response = executeAPI.executeStatement(stmt)
getResult(response) match {
case Left(error) => throw new DASSdkException(error)
case Right(result) => new DASDatabricksExecuteResult(executeAPI, result)
}
}

private def databricksColumnName(name: String): String = {
@@ -119,21 +123,19 @@ class DASDatabricksTable(client: WorkspaceClient, warehouseID: String, databrick
override def canSort(sortKeys: Seq[SortKey]): Seq[SortKey] = sortKeys

@tailrec
private def getResult(response: StatementResponse): StatementResponse = {
private def getResult(response: StatementResponse): Either[String, StatementResponse] = {
val state = response.getStatus.getState
logger.info(s"Query ${response.getStatementId} state: $state")
state match {
case StatementState.PENDING | StatementState.RUNNING =>
logger.info(s"Query is still running, polling again in $POLLING_TIME ms")
Thread.sleep(POLLING_TIME)
val response2 = client.statementExecution().getStatement(response.getStatementId)
getResult(response2)
case StatementState.SUCCEEDED => response
case StatementState.FAILED =>
throw new RuntimeException(s"Query failed: ${response.getStatus.getError.getMessage}")
case StatementState.CLOSED =>
throw new RuntimeException(s"Query closed: ${response.getStatus.getError.getMessage}")
case StatementState.CANCELED =>
throw new RuntimeException(s"Query canceled: ${response.getStatus.getError.getMessage}")
case StatementState.SUCCEEDED => Right(response)
case StatementState.FAILED => Left(s"Query failed: ${response.getStatus.getError.getMessage}")
case StatementState.CLOSED => Left(s"Query closed: ${response.getStatus.getError.getMessage}")
case StatementState.CANCELED => Left(s"Query canceled: ${response.getStatus.getError.getMessage}")
}
}
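The shape of `getResult` — a tail-recursive poll that sleeps between reads and returns an `Either` so each caller decides how to fail — is a reusable pattern. A dependency-free sketch under illustrative names (none of these are SDK APIs):

```scala
import scala.annotation.tailrec

object Polling {
  // `fetch` re-reads the remote state; `classify` returns None while the
  // operation is still running, or a terminal Left(error) / Right(value).
  @tailrec
  def pollUntilDone[A](state: A, intervalMs: Long)(fetch: A => A)(
      classify: A => Option[Either[String, A]]
  ): Either[String, A] =
    classify(state) match {
      case Some(terminal) => terminal
      case None =>
        Thread.sleep(intervalMs)
        pollUntilDone(fetch(state), intervalMs)(fetch)(classify)
    }
}
```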

@@ -161,6 +163,26 @@ class DASDatabricksTable(client: WorkspaceClient, warehouseID: String, databrick
definition.build()
}

// Primary key column name, if a single-column primary key is found in the table's constraint metadata.
private var primaryKeyColumn: Option[String] = None

// Try to find a primary key constraint over one column.
if (databricksTable.getTableConstraints == null) {
logger.warn(s"No constraints found for table $tableFullName")
} else {
databricksTable.getTableConstraints.forEach { constraint =>
val primaryKeyConstraint = constraint.getPrimaryKeyConstraint
if (primaryKeyConstraint != null) {
if (primaryKeyConstraint.getChildColumns.size != 1) {
logger.warn("Ignoring composite primary key")
} else {
val col = primaryKeyConstraint.getChildColumns.iterator().next()
primaryKeyColumn = Some(col)
logger.info(s"Found primary key ($col)")
}
}
}
}
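The same scan can be expressed as a pure function. A sketch using `collectFirst`, assuming only the SDK accessors already used above (`getTableConstraints`, `getPrimaryKeyConstraint`, `getChildColumns`):

```scala
import com.databricks.sdk.service.catalog.TableInfo
import scala.collection.JavaConverters._

// Returns the column name of a single-column primary key, if any;
// composite keys and tables without constraints yield None.
def singleColumnPrimaryKey(table: TableInfo): Option[String] =
  Option(table.getTableConstraints).toSeq
    .flatMap(_.asScala)
    .flatMap(c => Option(c.getPrimaryKeyConstraint))
    .collectFirst {
      case pk if pk.getChildColumns.size == 1 => pk.getChildColumns.iterator().next()
    }
```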

private def columnType(info: ColumnInfo): Option[Type] = {
val builder = Type.newBuilder()
val columnType = info.getTypeName
@@ -230,6 +252,11 @@ class DASDatabricksTable(client: WorkspaceClient, warehouseID: String, databrick
}
}

override def uniqueColumn: String = {
// Fall back to the first column if no primary key was found.
primaryKeyColumn.getOrElse(databricksTable.getColumns.asScala.head.getName)
}

private def rawValueToParameter(v: Value): StatementParameterListItem = {
logger.debug(s"Converting value to parameter: $v")
val parameter = new StatementParameterListItem()
@@ -286,4 +313,93 @@ class DASDatabricksTable(client: WorkspaceClient, warehouseID: String, databrick
}
}

override def insert(row: Row): Row = {
bulkInsert(Seq(row)).head
}

// INSERTs can be batched, but only by inlining values into the query string.
// We don't want to send gigantic query strings accidentally, so we keep each
// statement to roughly the following size.
private val MAX_INSERT_CODE_SIZE = 2048

override def bulkInsert(rows: Seq[Row]): Seq[Row] = {
// There's no bulk call in Databricks, so we inline values. We build batches
// of query strings of at most (roughly) MAX_INSERT_CODE_SIZE characters and
// loop until all rows are consumed.
val columnNames = databricksTable.getColumns.asScala.map(_.getName)
val values = rows.map { row =>
val data = row.getDataMap
columnNames
.map { name =>
val value = data.get(name)
if (value == null) {
"DEFAULT"
} else {
rawValueToDatabricksQueryString(value)
}
}
.mkString("(", ",", ")")
}
val stmt = new ExecuteStatementRequest()
.setWarehouseId(warehouseID)
.setDisposition(Disposition.INLINE)
.setFormat(Format.JSON_ARRAY)

val items = values.iterator
while (items.hasNext) {
val item = items.next()
val code = new StringBuilder
code.append(s"INSERT INTO $tableFullName VALUES $item")
while (code.size < MAX_INSERT_CODE_SIZE && items.hasNext) {
code.append(s",${items.next()}")
}
stmt.setStatement(code.toString())
val executeAPI = client.statementExecution()
val response = executeAPI.executeStatement(stmt)
getResult(response).left.foreach(error => throw new RuntimeException(error))
}
rows
}
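The batching above is a greedy, size-bounded packing: each statement starts with one value tuple and keeps appending tuples until it crosses the budget. The same step in isolation (illustrative names, not part of this class):

```scala
// Pack pre-rendered value tuples into statements of roughly maxChars
// characters. Every batch takes at least one tuple, so one oversized tuple
// still yields a single (long) statement instead of an infinite loop.
def packInserts(prefix: String, tuples: Seq[String], maxChars: Int): Seq[String] = {
  val batches = Seq.newBuilder[String]
  val it = tuples.iterator
  while (it.hasNext) {
    val sb = new StringBuilder(prefix).append(it.next())
    while (sb.length < maxChars && it.hasNext) sb.append(',').append(it.next())
    batches += sb.toString()
  }
  batches.result()
}

// e.g. packInserts("INSERT INTO t VALUES ", Seq("(1,'a')", "(2,'b')"), 2048)
// => Seq("INSERT INTO t VALUES (1,'a'),(2,'b')")
```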

override def delete(rowId: Value): Unit = {
val stmt = new ExecuteStatementRequest()
.setWarehouseId(warehouseID)
.setDisposition(Disposition.INLINE)
.setFormat(Format.JSON_ARRAY)
stmt.setStatement(
s"DELETE FROM ${databricksTable.getName} WHERE ${databricksColumnName(uniqueColumn)} = ${rawValueToDatabricksQueryString(rowId)}"
)
val executeAPI = client.statementExecution()
val response = executeAPI.executeStatement(stmt)
getResult(response).left.foreach(error => throw new RuntimeException(error))
}

// How many rows are accepted in a batch update. Technically unlimited,
// since updates are sent one by one.
private val MODIFY_BATCH_SIZE = 1000

override def modifyBatchSize: Int = {
MODIFY_BATCH_SIZE
}

override def update(rowId: Value, newValues: Row): Row = {
val buffer = mutable.Buffer.empty[String]
newValues.getDataMap
.forEach {
case (name, value) =>
buffer.append(s"${databricksColumnName(name)} = ${rawValueToDatabricksQueryString(value)}")
}
val setValues = buffer.mkString(", ")
val stmt = new ExecuteStatementRequest()
.setWarehouseId(warehouseID)
.setDisposition(Disposition.INLINE)
.setFormat(Format.JSON_ARRAY)
stmt.setStatement(
s"UPDATE ${databricksTable.getName} SET $setValues WHERE ${databricksColumnName(uniqueColumn)} = ${rawValueToDatabricksQueryString(rowId)}"
)
val executeAPI = client.statementExecution()
val response = executeAPI.executeStatement(stmt)
getResult(response).left.foreach(error => throw new RuntimeException(error))
newValues
}
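For a hypothetical table with primary key `id`, updating row 42 with `qty = 5` would send a statement along the lines of `UPDATE schema.table SET qty = 5 WHERE id = 42`: like DELETE, the UPDATE path inlines values through `rawValueToDatabricksQueryString` instead of binding them as parameters the way `execute` does.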
}
3 changes: 0 additions & 3 deletions src/main/scala/com/rawlabs/das/databricks/id.kt

This file was deleted.
