Skip to content

Commit

Permalink
RD-11521: report connection failures to the user (#469)
Browse files Browse the repository at this point in the history
  • Loading branch information
bgaidioz authored Aug 2, 2024
1 parent ab4e86a commit 3bb31a3
Show file tree
Hide file tree
Showing 5 changed files with 722 additions and 97 deletions.
170 changes: 102 additions & 68 deletions sql-client/src/main/scala/raw/client/sql/SqlCompilerService.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import raw.client.sql.writers.{TypedResultSetCsvWriter, TypedResultSetJsonWriter
import raw.utils.{RawSettings, RawUtils}

import java.io.{IOException, OutputStream}
import java.sql.ResultSet
import java.sql.{ResultSet, SQLException}
import scala.util.control.NonFatal

/**
Expand Down Expand Up @@ -86,66 +86,82 @@ class SqlCompilerService()(implicit protected val settings: RawSettings) extends
safeParse(source) match {
case Left(errors) => GetProgramDescriptionFailure(errors)
case Right(parsedTree) =>
val conn = connectionPool.getConnection(environment.jdbcUrl.get)
try {
val stmt = new NamedParametersPreparedStatement(conn, parsedTree)
val description = stmt.queryMetadata match {
case Right(info) =>
val queryParamInfo = info.parameters
val outputType = pgRowTypeToIterableType(info.outputType)
val parameterInfo = queryParamInfo
.map {
case (name, paramInfo) => SqlTypesUtils.rawTypeFromPgType(paramInfo.pgType).map { rawType =>
// we ignore tipe.nullable and mark all parameters as nullable
val paramType = rawType match {
case RawAnyType() => rawType;
case other => other.cloneNullable
val conn = connectionPool.getConnection(environment.jdbcUrl.get)
try {
val stmt = new NamedParametersPreparedStatement(conn, parsedTree)
val description = stmt.queryMetadata match {
case Right(info) =>
val queryParamInfo = info.parameters
val outputType = pgRowTypeToIterableType(info.outputType)
val parameterInfo = queryParamInfo
.map {
case (name, paramInfo) => SqlTypesUtils.rawTypeFromPgType(paramInfo.pgType).map { rawType =>
// we ignore tipe.nullable and mark all parameters as nullable
val paramType = rawType match {
case RawAnyType() => rawType;
case other => other.cloneNullable
}
ParamDescription(
name,
Some(paramType),
paramInfo.default,
comment = paramInfo.comment,
required = paramInfo.default.isEmpty
)
}
ParamDescription(
name,
Some(paramType),
paramInfo.default,
comment = paramInfo.comment,
required = paramInfo.default.isEmpty
)
}
}
.foldLeft(Right(Seq.empty): Either[Seq[String], Seq[ParamDescription]]) {
case (Left(errors), Left(error)) => Left(errors :+ error)
case (_, Left(error)) => Left(Seq(error))
case (Right(params), Right(param)) => Right(params :+ param)
case (errors @ Left(_), _) => errors
case (_, Right(param)) => Right(Seq(param))
}
(outputType, parameterInfo) match {
case (Right(iterableType), Right(ps)) =>
// Regardless if there are parameters, we declare a main function with the output type.
// This permits the publish endpoints from the UI (https://raw-labs.atlassian.net/browse/RD-10359)
val ok = ProgramDescription(
Map.empty,
Some(DeclDescription(Some(ps.toVector), Some(iterableType), None)),
None
)
GetProgramDescriptionSuccess(ok)
case _ =>
val errorMessages =
outputType.left.getOrElse(Seq.empty) ++ parameterInfo.left.getOrElse(Seq.empty)
GetProgramDescriptionFailure(treeErrors(parsedTree, errorMessages).toList)
}
.foldLeft(Right(Seq.empty): Either[Seq[String], Seq[ParamDescription]]) {
case (Left(errors), Left(error)) => Left(errors :+ error)
case (_, Left(error)) => Left(Seq(error))
case (Right(params), Right(param)) => Right(params :+ param)
case (errors @ Left(_), _) => errors
case (_, Right(param)) => Right(Seq(param))
}
(outputType, parameterInfo) match {
case (Right(iterableType), Right(ps)) =>
// Regardless if there are parameters, we declare a main function with the output type.
// This permits the publish endpoints from the UI (https://raw-labs.atlassian.net/browse/RD-10359)
val ok = ProgramDescription(
Map.empty,
Some(DeclDescription(Some(ps.toVector), Some(iterableType), None)),
None
)
GetProgramDescriptionSuccess(ok)
case _ =>
val errorMessages = outputType.left.getOrElse(Seq.empty) ++ parameterInfo.left.getOrElse(Seq.empty)
GetProgramDescriptionFailure(treeErrors(parsedTree, errorMessages).toList)
}
case Left(errors) => GetProgramDescriptionFailure(errors)
case Left(errors) => GetProgramDescriptionFailure(errors)
}
RawUtils.withSuppressNonFatalException(stmt.close())
description
} catch {
case e: NamedParametersPreparedStatementException => GetProgramDescriptionFailure(e.errors)
} finally {
RawUtils.withSuppressNonFatalException(conn.close())
}
RawUtils.withSuppressNonFatalException(stmt.close())
description
} catch {
case e: NamedParametersPreparedStatementException => GetProgramDescriptionFailure(e.errors)
} finally {
RawUtils.withSuppressNonFatalException(conn.close())
case ex: SQLException if isConnectionFailure(ex) =>
logger.warn("SqlConnectionPool connection failure", ex)
GetProgramDescriptionFailure(List(treeError(parsedTree, ex.getMessage)))
}
}
} catch {
case NonFatal(t) => throw new CompilerServiceException(t, environment)
}
}

private def treeError(parsedTree: ParseProgramResult, msg: String) = {
val tree = parsedTree.tree
val start = parsedTree.positions.getStart(tree).get
val end = parsedTree.positions.getFinish(tree).get
val startPos = ErrorPosition(start.line, start.column)
val endPos = ErrorPosition(end.line, end.column)
ErrorMessage(msg, List(ErrorRange(startPos, endPos)), ErrorCode.SqlErrorCode)
}

override def execute(
source: String,
environment: ProgramEnvironment,
Expand Down Expand Up @@ -182,9 +198,11 @@ class SqlCompilerService()(implicit protected val settings: RawSettings) extends
} finally {
RawUtils.withSuppressNonFatalException(conn.close())
}

}
} catch {
case ex: SQLException if isConnectionFailure(ex) =>
logger.warn("SqlConnectionPool connection failure", ex)
ExecutionRuntimeFailure(ex.getMessage)
case NonFatal(t) => throw new CompilerServiceException(t, environment)
}
}
Expand Down Expand Up @@ -319,21 +337,27 @@ class SqlCompilerService()(implicit protected val settings: RawSettings) extends
.map { case (names, tipe) => HoverResponse(Some(TypeCompletion(formatIdns(names), tipe))) }
.getOrElse(HoverResponse(None))
case use: SqlParamUseNode =>
val conn = connectionPool.getConnection(environment.jdbcUrl.get)
try {
val pstmt = new NamedParametersPreparedStatement(conn, tree)
val conn = connectionPool.getConnection(environment.jdbcUrl.get)
try {
pstmt.parameterInfo(use.name) match {
case Right(typeInfo) => HoverResponse(Some(TypeCompletion(use.name, typeInfo.pgType.typeName)))
case Left(_) => HoverResponse(None)
val pstmt = new NamedParametersPreparedStatement(conn, tree)
try {
pstmt.parameterInfo(use.name) match {
case Right(typeInfo) => HoverResponse(Some(TypeCompletion(use.name, typeInfo.pgType.typeName)))
case Left(_) => HoverResponse(None)
}
} finally {
RawUtils.withSuppressNonFatalException(pstmt.close())
}
} catch {
case _: NamedParametersPreparedStatementException => HoverResponse(None)
} finally {
RawUtils.withSuppressNonFatalException(pstmt.close())
RawUtils.withSuppressNonFatalException(conn.close())
}
} catch {
case _: NamedParametersPreparedStatementException => HoverResponse(None)
} finally {
RawUtils.withSuppressNonFatalException(conn.close())
case ex: SQLException if isConnectionFailure(ex) =>
logger.warn("SqlConnectionPool connection failure", ex)
HoverResponse(None)
}
}
.getOrElse(HoverResponse(None))
Expand Down Expand Up @@ -372,21 +396,27 @@ class SqlCompilerService()(implicit protected val settings: RawSettings) extends
safeParse(source) match {
case Left(errors) => ValidateResponse(errors)
case Right(parsedTree) =>
val conn = connectionPool.getConnection(environment.jdbcUrl.get)
try {
val stmt = new NamedParametersPreparedStatement(conn, parsedTree)
val conn = connectionPool.getConnection(environment.jdbcUrl.get)
try {
stmt.queryMetadata match {
case Right(_) => ValidateResponse(List.empty)
case Left(errors) => ValidateResponse(errors)
val stmt = new NamedParametersPreparedStatement(conn, parsedTree)
try {
stmt.queryMetadata match {
case Right(_) => ValidateResponse(List.empty)
case Left(errors) => ValidateResponse(errors)
}
} finally {
RawUtils.withSuppressNonFatalException(stmt.close())
}
} catch {
case e: NamedParametersPreparedStatementException => ValidateResponse(e.errors)
} finally {
RawUtils.withSuppressNonFatalException(stmt.close())
RawUtils.withSuppressNonFatalException(conn.close())
}
} catch {
case e: NamedParametersPreparedStatementException => ValidateResponse(e.errors)
} finally {
RawUtils.withSuppressNonFatalException(conn.close())
case ex: SQLException if isConnectionFailure(ex) =>
logger.warn("SqlConnectionPool connection failure", ex)
ValidateResponse(List(treeError(parsedTree, ex.getMessage)))
}
}
} catch {
Expand Down Expand Up @@ -421,4 +451,8 @@ class SqlCompilerService()(implicit protected val settings: RawSettings) extends
rowAttrTypes.right.map(attrs => RawIterableType(RawRecordType(attrs.toVector, false, false), false, false))
}

private def isConnectionFailure(ex: SQLException) = {
val state = ex.getSQLState
state != null && state.startsWith("08") // connection exception, SqlConnectionPool is full
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@ class SqlConnectionPool()(implicit settings: RawSettings) extends RawService wit
// If no connection is currently available, just check if this specific db location is maxed out.
if (maybeConn.isEmpty && conns.size >= maxConnectionsPerDb) {
// No connection was available to borrow, and too many being used for this db location, so we cannot open
// any more!
throw new SQLException("too many connections active")
// any more! We throw with code 08000 (SQLSTATE "connection exception" in SQL standard).
throw new SQLException("too many connections active", "08000")
}
}

Expand All @@ -165,7 +165,7 @@ class SqlConnectionPool()(implicit settings: RawSettings) extends RawService wit

// We could not successfully release any connection, so bail out.
if (getTotalActiveConnections() >= maxConnections) {
throw new SQLException("no connections available")
throw new SQLException("no connections available", "08000")
}

// Create a new connection.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,13 @@ package raw.client.sql.metadata

import com.google.common.cache.{CacheBuilder, CacheLoader}
import com.typesafe.scalalogging.StrictLogging

import java.time.Duration
import raw.client.sql.antlr4.{SqlIdentifierNode, SqlIdnNode, SqlProjNode}
import raw.client.sql.{SqlConnectionPool, SqlIdentifier}

import java.sql.SQLException

case class IdentifierInfo(name: Seq[SqlIdentifier], tipe: String)

/* This class is used to cache metadata info about the user's database.
Expand All @@ -34,17 +37,23 @@ class UserMetadataCache(jdbcUrl: String, connectionPool: SqlConnectionPool, maxS
private val wordCompletionCache = {
val loader = new CacheLoader[Seq[SqlIdentifier], Seq[IdentifierInfo]]() {
override def load(idns: Seq[SqlIdentifier]): Seq[IdentifierInfo] = {
val con = connectionPool.getConnection(jdbcUrl)
try {
val query = idns.size match {
case 3 => WordSearchWithThreeItems
case 2 => WordSearchWithTwoItems
case 1 => WordSearchWithOneItem
val con = connectionPool.getConnection(jdbcUrl)
try {
val query = idns.size match {
case 3 => WordSearchWithThreeItems
case 2 => WordSearchWithTwoItems
case 1 => WordSearchWithOneItem
}
val tokens = idns.map(idn => if (idn.quoted) idn.value else idn.value.toLowerCase)
query.run(con, tokens)
} finally {
con.close()
}
val tokens = idns.map(idn => if (idn.quoted) idn.value else idn.value.toLowerCase)
query.run(con, tokens)
} finally {
con.close()
} catch {
case ex: SQLException if isConnectionFailure(ex) =>
logger.warn("SqlConnectionPool connection failure", ex)
Seq.empty
}
}
}
Expand Down Expand Up @@ -100,16 +109,22 @@ class UserMetadataCache(jdbcUrl: String, connectionPool: SqlConnectionPool, maxS
private val dotCompletionCache = {
val loader = new CacheLoader[Seq[SqlIdentifier], Seq[IdentifierInfo]]() {
override def load(idns: Seq[SqlIdentifier]): Seq[IdentifierInfo] = {
val con = connectionPool.getConnection(jdbcUrl)
try {
val query = idns.size match {
case 2 => DotSearchWithTwoItems
case 1 => DotSearchWithOneItem
val con = connectionPool.getConnection(jdbcUrl)
try {
val query = idns.size match {
case 2 => DotSearchWithTwoItems
case 1 => DotSearchWithOneItem
}
val tokens = idns.map(idn => if (idn.quoted) idn.value else idn.value.toLowerCase)
query.run(con, tokens)
} finally {
con.close()
}
val tokens = idns.map(idn => if (idn.quoted) idn.value else idn.value.toLowerCase)
query.run(con, tokens)
} finally {
con.close()
} catch {
case ex: SQLException if isConnectionFailure(ex) =>
logger.warn("SqlConnectionPool connection failure", ex)
Seq.empty
}
}
}
Expand All @@ -132,4 +147,8 @@ class UserMetadataCache(jdbcUrl: String, connectionPool: SqlConnectionPool, maxS
dotCompletionCache.get(seq).map(i => (i.name, i.tipe))
}

private def isConnectionFailure(ex: SQLException) = {
val state = ex.getSQLState
state != null && state.startsWith("08") // connection exception, SqlConnectionPool is full
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,6 @@
* licenses/APL.txt.
*/

/* we probably need to make parameters optional and replace them by null when not specified. No default value
where does one get the credentials from?
let's try to find a decent library that deals with parameters?
since we run on Postgres we can use the Postgres JDBC driver exceptions and be more specific with error handling
- some messages show line and column numbers, we could use that to provide more precise error messages
- but it will be messed up if we don't account for question marks
one should return correct runtime/validation failures
*/

package raw.client.sql

import com.dimafeng.testcontainers.{ForAllTestContainer, PostgreSQLContainer}
Expand Down
Loading

0 comments on commit 3bb31a3

Please sign in to comment.