diff --git a/sql-client/src/main/scala/raw/client/sql/SqlCompilerService.scala b/sql-client/src/main/scala/raw/client/sql/SqlCompilerService.scala index 699198b7f..4027d0ed3 100644 --- a/sql-client/src/main/scala/raw/client/sql/SqlCompilerService.scala +++ b/sql-client/src/main/scala/raw/client/sql/SqlCompilerService.scala @@ -21,7 +21,7 @@ import raw.client.sql.writers.{TypedResultSetCsvWriter, TypedResultSetJsonWriter import raw.utils.{RawSettings, RawUtils} import java.io.{IOException, OutputStream} -import java.sql.ResultSet +import java.sql.{ResultSet, SQLException} import scala.util.control.NonFatal /** @@ -86,59 +86,66 @@ class SqlCompilerService()(implicit protected val settings: RawSettings) extends safeParse(source) match { case Left(errors) => GetProgramDescriptionFailure(errors) case Right(parsedTree) => - val conn = connectionPool.getConnection(environment.jdbcUrl.get) try { - val stmt = new NamedParametersPreparedStatement(conn, parsedTree) - val description = stmt.queryMetadata match { - case Right(info) => - val queryParamInfo = info.parameters - val outputType = pgRowTypeToIterableType(info.outputType) - val parameterInfo = queryParamInfo - .map { - case (name, paramInfo) => SqlTypesUtils.rawTypeFromPgType(paramInfo.pgType).map { rawType => - // we ignore tipe.nullable and mark all parameters as nullable - val paramType = rawType match { - case RawAnyType() => rawType; - case other => other.cloneNullable + val conn = connectionPool.getConnection(environment.jdbcUrl.get) + try { + val stmt = new NamedParametersPreparedStatement(conn, parsedTree) + val description = stmt.queryMetadata match { + case Right(info) => + val queryParamInfo = info.parameters + val outputType = pgRowTypeToIterableType(info.outputType) + val parameterInfo = queryParamInfo + .map { + case (name, paramInfo) => SqlTypesUtils.rawTypeFromPgType(paramInfo.pgType).map { rawType => + // we ignore tipe.nullable and mark all parameters as nullable + val paramType = rawType match { + case RawAnyType() => rawType; + case other => other.cloneNullable + } + ParamDescription( + name, + Some(paramType), + paramInfo.default, + comment = paramInfo.comment, + required = paramInfo.default.isEmpty + ) } - ParamDescription( - name, - Some(paramType), - paramInfo.default, - comment = paramInfo.comment, - required = paramInfo.default.isEmpty - ) - } + } + .foldLeft(Right(Seq.empty): Either[Seq[String], Seq[ParamDescription]]) { + case (Left(errors), Left(error)) => Left(errors :+ error) + case (_, Left(error)) => Left(Seq(error)) + case (Right(params), Right(param)) => Right(params :+ param) + case (errors @ Left(_), _) => errors + case (_, Right(param)) => Right(Seq(param)) + } + (outputType, parameterInfo) match { + case (Right(iterableType), Right(ps)) => + // Regardless if there are parameters, we declare a main function with the output type. 
+ // This permits the publish endpoints from the UI (https://raw-labs.atlassian.net/browse/RD-10359) + val ok = ProgramDescription( + Map.empty, + Some(DeclDescription(Some(ps.toVector), Some(iterableType), None)), + None + ) + GetProgramDescriptionSuccess(ok) + case _ => + val errorMessages = + outputType.left.getOrElse(Seq.empty) ++ parameterInfo.left.getOrElse(Seq.empty) + GetProgramDescriptionFailure(treeErrors(parsedTree, errorMessages).toList) } - .foldLeft(Right(Seq.empty): Either[Seq[String], Seq[ParamDescription]]) { - case (Left(errors), Left(error)) => Left(errors :+ error) - case (_, Left(error)) => Left(Seq(error)) - case (Right(params), Right(param)) => Right(params :+ param) - case (errors @ Left(_), _) => errors - case (_, Right(param)) => Right(Seq(param)) - } - (outputType, parameterInfo) match { - case (Right(iterableType), Right(ps)) => - // Regardless if there are parameters, we declare a main function with the output type. - // This permits the publish endpoints from the UI (https://raw-labs.atlassian.net/browse/RD-10359) - val ok = ProgramDescription( - Map.empty, - Some(DeclDescription(Some(ps.toVector), Some(iterableType), None)), - None - ) - GetProgramDescriptionSuccess(ok) - case _ => - val errorMessages = outputType.left.getOrElse(Seq.empty) ++ parameterInfo.left.getOrElse(Seq.empty) - GetProgramDescriptionFailure(treeErrors(parsedTree, errorMessages).toList) - } - case Left(errors) => GetProgramDescriptionFailure(errors) + case Left(errors) => GetProgramDescriptionFailure(errors) + } + RawUtils.withSuppressNonFatalException(stmt.close()) + description + } catch { + case e: NamedParametersPreparedStatementException => GetProgramDescriptionFailure(e.errors) + } finally { + RawUtils.withSuppressNonFatalException(conn.close()) } - RawUtils.withSuppressNonFatalException(stmt.close()) - description } catch { - case e: NamedParametersPreparedStatementException => GetProgramDescriptionFailure(e.errors) - } finally { - RawUtils.withSuppressNonFatalException(conn.close()) + case ex: SQLException if isConnectionFailure(ex) => + logger.warn("SqlConnectionPool connection failure", ex) + GetProgramDescriptionFailure(List(treeError(parsedTree, ex.getMessage))) } } } catch { @@ -146,6 +153,15 @@ class SqlCompilerService()(implicit protected val settings: RawSettings) extends } } + private def treeError(parsedTree: ParseProgramResult, msg: String) = { + val tree = parsedTree.tree + val start = parsedTree.positions.getStart(tree).get + val end = parsedTree.positions.getFinish(tree).get + val startPos = ErrorPosition(start.line, start.column) + val endPos = ErrorPosition(end.line, end.column) + ErrorMessage(msg, List(ErrorRange(startPos, endPos)), ErrorCode.SqlErrorCode) + } + override def execute( source: String, environment: ProgramEnvironment, @@ -182,9 +198,11 @@ class SqlCompilerService()(implicit protected val settings: RawSettings) extends } finally { RawUtils.withSuppressNonFatalException(conn.close()) } - } } catch { + case ex: SQLException if isConnectionFailure(ex) => + logger.warn("SqlConnectionPool connection failure", ex) + ExecutionRuntimeFailure(ex.getMessage) case NonFatal(t) => throw new CompilerServiceException(t, environment) } } @@ -319,21 +337,27 @@ class SqlCompilerService()(implicit protected val settings: RawSettings) extends .map { case (names, tipe) => HoverResponse(Some(TypeCompletion(formatIdns(names), tipe))) } .getOrElse(HoverResponse(None)) case use: SqlParamUseNode => - val conn = connectionPool.getConnection(environment.jdbcUrl.get) try { - val 
pstmt = new NamedParametersPreparedStatement(conn, tree) + val conn = connectionPool.getConnection(environment.jdbcUrl.get) try { - pstmt.parameterInfo(use.name) match { - case Right(typeInfo) => HoverResponse(Some(TypeCompletion(use.name, typeInfo.pgType.typeName))) - case Left(_) => HoverResponse(None) + val pstmt = new NamedParametersPreparedStatement(conn, tree) + try { + pstmt.parameterInfo(use.name) match { + case Right(typeInfo) => HoverResponse(Some(TypeCompletion(use.name, typeInfo.pgType.typeName))) + case Left(_) => HoverResponse(None) + } + } finally { + RawUtils.withSuppressNonFatalException(pstmt.close()) } + } catch { + case _: NamedParametersPreparedStatementException => HoverResponse(None) } finally { - RawUtils.withSuppressNonFatalException(pstmt.close()) + RawUtils.withSuppressNonFatalException(conn.close()) } } catch { - case _: NamedParametersPreparedStatementException => HoverResponse(None) - } finally { - RawUtils.withSuppressNonFatalException(conn.close()) + case ex: SQLException if isConnectionFailure(ex) => + logger.warn("SqlConnectionPool connection failure", ex) + HoverResponse(None) } } .getOrElse(HoverResponse(None)) @@ -372,21 +396,27 @@ class SqlCompilerService()(implicit protected val settings: RawSettings) extends safeParse(source) match { case Left(errors) => ValidateResponse(errors) case Right(parsedTree) => - val conn = connectionPool.getConnection(environment.jdbcUrl.get) try { - val stmt = new NamedParametersPreparedStatement(conn, parsedTree) + val conn = connectionPool.getConnection(environment.jdbcUrl.get) try { - stmt.queryMetadata match { - case Right(_) => ValidateResponse(List.empty) - case Left(errors) => ValidateResponse(errors) + val stmt = new NamedParametersPreparedStatement(conn, parsedTree) + try { + stmt.queryMetadata match { + case Right(_) => ValidateResponse(List.empty) + case Left(errors) => ValidateResponse(errors) + } + } finally { + RawUtils.withSuppressNonFatalException(stmt.close()) } + } catch { + case e: NamedParametersPreparedStatementException => ValidateResponse(e.errors) } finally { - RawUtils.withSuppressNonFatalException(stmt.close()) + RawUtils.withSuppressNonFatalException(conn.close()) } } catch { - case e: NamedParametersPreparedStatementException => ValidateResponse(e.errors) - } finally { - RawUtils.withSuppressNonFatalException(conn.close()) + case ex: SQLException if isConnectionFailure(ex) => + logger.warn("SqlConnectionPool connection failure", ex) + ValidateResponse(List(treeError(parsedTree, ex.getMessage))) } } } catch { @@ -421,4 +451,8 @@ class SqlCompilerService()(implicit protected val settings: RawSettings) extends rowAttrTypes.right.map(attrs => RawIterableType(RawRecordType(attrs.toVector, false, false), false, false)) } + private def isConnectionFailure(ex: SQLException) = { + val state = ex.getSQLState + state != null && state.startsWith("08") // connection exception, SqlConnectionPool is full + } } diff --git a/sql-client/src/main/scala/raw/client/sql/SqlConnectionPool.scala b/sql-client/src/main/scala/raw/client/sql/SqlConnectionPool.scala index 46992cd64..794834048 100644 --- a/sql-client/src/main/scala/raw/client/sql/SqlConnectionPool.scala +++ b/sql-client/src/main/scala/raw/client/sql/SqlConnectionPool.scala @@ -145,8 +145,8 @@ class SqlConnectionPool()(implicit settings: RawSettings) extends RawService wit // If no connection is currently available, just check if this specific db location is maxed out. 
if (maybeConn.isEmpty && conns.size >= maxConnectionsPerDb) { // No connection was available to borrow, and too many being used for this db location, so we cannot open - // any more! - throw new SQLException("too many connections active") + // any more! We throw with code 08000 (SQLSTATE "connection exception" in SQL standard). + throw new SQLException("too many connections active", "08000") } } @@ -165,7 +165,7 @@ class SqlConnectionPool()(implicit settings: RawSettings) extends RawService wit // We could not successfully release any connection, so bail out. if (getTotalActiveConnections() >= maxConnections) { - throw new SQLException("no connections available") + throw new SQLException("no connections available", "08000") } // Create a new connection. diff --git a/sql-client/src/main/scala/raw/client/sql/metadata/UserMetadataCache.scala b/sql-client/src/main/scala/raw/client/sql/metadata/UserMetadataCache.scala index 1a59f9d8d..c64e587ca 100644 --- a/sql-client/src/main/scala/raw/client/sql/metadata/UserMetadataCache.scala +++ b/sql-client/src/main/scala/raw/client/sql/metadata/UserMetadataCache.scala @@ -14,10 +14,13 @@ package raw.client.sql.metadata import com.google.common.cache.{CacheBuilder, CacheLoader} import com.typesafe.scalalogging.StrictLogging + import java.time.Duration import raw.client.sql.antlr4.{SqlIdentifierNode, SqlIdnNode, SqlProjNode} import raw.client.sql.{SqlConnectionPool, SqlIdentifier} +import java.sql.SQLException + case class IdentifierInfo(name: Seq[SqlIdentifier], tipe: String) /* This class is used to cache metadata info about the user's database. @@ -34,17 +37,23 @@ class UserMetadataCache(jdbcUrl: String, connectionPool: SqlConnectionPool, maxS private val wordCompletionCache = { val loader = new CacheLoader[Seq[SqlIdentifier], Seq[IdentifierInfo]]() { override def load(idns: Seq[SqlIdentifier]): Seq[IdentifierInfo] = { - val con = connectionPool.getConnection(jdbcUrl) try { - val query = idns.size match { - case 3 => WordSearchWithThreeItems - case 2 => WordSearchWithTwoItems - case 1 => WordSearchWithOneItem + val con = connectionPool.getConnection(jdbcUrl) + try { + val query = idns.size match { + case 3 => WordSearchWithThreeItems + case 2 => WordSearchWithTwoItems + case 1 => WordSearchWithOneItem + } + val tokens = idns.map(idn => if (idn.quoted) idn.value else idn.value.toLowerCase) + query.run(con, tokens) + } finally { + con.close() } - val tokens = idns.map(idn => if (idn.quoted) idn.value else idn.value.toLowerCase) - query.run(con, tokens) - } finally { - con.close() + } catch { + case ex: SQLException if isConnectionFailure(ex) => + logger.warn("SqlConnectionPool connection failure", ex) + Seq.empty } } } @@ -100,16 +109,22 @@ class UserMetadataCache(jdbcUrl: String, connectionPool: SqlConnectionPool, maxS private val dotCompletionCache = { val loader = new CacheLoader[Seq[SqlIdentifier], Seq[IdentifierInfo]]() { override def load(idns: Seq[SqlIdentifier]): Seq[IdentifierInfo] = { - val con = connectionPool.getConnection(jdbcUrl) try { - val query = idns.size match { - case 2 => DotSearchWithTwoItems - case 1 => DotSearchWithOneItem + val con = connectionPool.getConnection(jdbcUrl) + try { + val query = idns.size match { + case 2 => DotSearchWithTwoItems + case 1 => DotSearchWithOneItem + } + val tokens = idns.map(idn => if (idn.quoted) idn.value else idn.value.toLowerCase) + query.run(con, tokens) + } finally { + con.close() } - val tokens = idns.map(idn => if (idn.quoted) idn.value else idn.value.toLowerCase) - query.run(con, tokens) - } 
finally { - con.close() + } catch { + case ex: SQLException if isConnectionFailure(ex) => + logger.warn("SqlConnectionPool connection failure", ex) + Seq.empty } } } @@ -132,4 +147,8 @@ class UserMetadataCache(jdbcUrl: String, connectionPool: SqlConnectionPool, maxS dotCompletionCache.get(seq).map(i => (i.name, i.tipe)) } + private def isConnectionFailure(ex: SQLException) = { + val state = ex.getSQLState + state != null && state.startsWith("08") // connection exception, SqlConnectionPool is full + } } diff --git a/sql-client/src/test/scala/raw/client/sql/TestSqlCompilerServiceAirports.scala b/sql-client/src/test/scala/raw/client/sql/TestSqlCompilerServiceAirports.scala index f5136a7e7..832400c82 100644 --- a/sql-client/src/test/scala/raw/client/sql/TestSqlCompilerServiceAirports.scala +++ b/sql-client/src/test/scala/raw/client/sql/TestSqlCompilerServiceAirports.scala @@ -10,15 +10,6 @@ * licenses/APL.txt. */ -/* we probably need to make parameters optional and replace them by null when not specified. No default value - where does one get the credentials from? - let's try to find a decent library that deals with parameters? - since we run on Postgres we can use the Postgres JDBC driver exceptions and be more specific with error handling - - some messages show line and column numbers, we could use that to provide more precise error messages - - but it will be messed up if we don't account for question marks - one should return correct runtime/validation failures - */ - package raw.client.sql import com.dimafeng.testcontainers.{ForAllTestContainer, PostgreSQLContainer} diff --git a/sql-client/src/test/scala/raw/client/sql/TestSqlConnectionFailures.scala b/sql-client/src/test/scala/raw/client/sql/TestSqlConnectionFailures.scala new file mode 100644 index 000000000..da9d47246 --- /dev/null +++ b/sql-client/src/test/scala/raw/client/sql/TestSqlConnectionFailures.scala @@ -0,0 +1,581 @@ +/* + * Copyright 2024 RAW Labs S.A. + * + * Use of this software is governed by the Business Source License + * included in the file licenses/BSL.txt. + * + * As of the Change Date specified in that file, in accordance with + * the Business Source License, use of this software will be governed + * by the Apache License, Version 2.0, included in the file + * licenses/APL.txt. + */ + +package raw.client.sql + +import com.dimafeng.testcontainers.{ForAllTestContainer, PostgreSQLContainer} +import org.scalatest.matchers.must.Matchers.be +import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper +import org.testcontainers.utility.DockerImageName +import raw.client.api._ +import raw.utils._ + +import java.io.ByteArrayOutputStream +import java.sql.DriverManager +import java.util.concurrent.{Executors, TimeUnit} +import scala.io.Source + +class TestSqlConnectionFailures + extends RawTestSuite + with ForAllTestContainer + with SettingsTestContext + with TrainingWheelsContext { + + // The test suite triggers connection failures for both 'no connections + // available' (the pool can't open a new connection) and 'too many + // connections active' (the user runs too many queries in parallel) + // for all implemented calls: execute, validate, getProgramDescription, + // hover, dotCompletion, wordCompletion. For each we use sequential + // or parallel queries to exhaust the pool in some way and assert the + // failure is hit as expected. + + // Number of users to run with. This allows testing errors that occur + // when a single user exhausts their allocated share. 
+ val nUsers = 3 + + // We run a test container emulating FDW. It has the example schema. + override val container: PostgreSQLContainer = PostgreSQLContainer( + dockerImageNameOverride = DockerImageName.parse("postgres:15-alpine") + ) + Class.forName("org.postgresql.Driver") + + private var users: Set[InteractiveUser] = _ + + private def jdbcUrl(user: AuthenticatedUser) = { + val dbPort = container.mappedPort(5432).toString + val dbName = user.uid + val username = container.username + val password = container.password + s"jdbc:postgresql://localhost:$dbPort/$dbName?user=$username&password=$password" + } + + override def beforeAll(): Unit = { + super.beforeAll() + + // For each user we create a specific database and load the example schema. + users = { + val items = for (i <- 1 to nUsers) yield InteractiveUser(Uid(s"db$i"), "fdw user", "email", Seq.empty) + items.toSet + } + + val exampleSchemaCreation = { + val resource = Source.fromResource("example.sql") + try { + resource.mkString + } finally { + resource.close() + } + } + + val conn = DriverManager.getConnection(container.jdbcUrl, container.username, container.password); + try { + val stmt = conn.createStatement() + for (user <- users) { + val r = stmt.executeUpdate(s"CREATE DATABASE ${user.uid.uid}") + assert(r == 0) + } + } finally { + conn.close() + } + + for (user <- users) { + val conn = DriverManager.getConnection(jdbcUrl(user), container.username, container.password); + try { + val stmt = conn.createStatement() + stmt.execute(exampleSchemaCreation) + } finally { + conn.close() + } + } + + } + + test("[lsp] enough connections in total") { _ => + // A single user calls LSP, while all others are running a long query. The user manages to get results + // because it could pick a connection. + val joe = users.head + val others = users.tail + property("raw.client.sql.pool.max-connections", s"$nUsers") // enough for everyone + property("raw.client.sql.pool.max-connections-per-db", s"1") + val compilerService = new SqlCompilerService() + val pool = Executors.newFixedThreadPool(others.size) + try { + // All other users run a long query which picks a connection for them + val futures = others.map(user => pool.submit(() => runExecute(compilerService, user, longRunningQuery, 5))) + val results = futures.map(_.get(60, TimeUnit.SECONDS)) + results.foreach { + case ExecutionSuccess(complete) => complete shouldBe true + case r => fail(s"unexpected result $r") + } + // The user is able to get a connection to run all LSP calls. 
+      // hover over 'example' picks a connection from the postgres metadata cache
+      val hoverResponse = runHover(compilerService, joe, "SELECT * FROM example.airports", Pos(1, 17))
+      assert(hoverResponse.completion.contains(TypeCompletion("example", "schema")))
+      // hover over ':id' picks a connection by asking the statement metadata to infer the type
+      val hoverResponse2 =
+        runHover(compilerService, joe, "SELECT * FROM example.airports WHERE :id = airport_id", Pos(1, 40))
+      assert(hoverResponse2.completion.contains(TypeCompletion("id", "integer")))
+      val wordCompletionResponse = runWordCompletion(compilerService, joe, "SELECT * FROM exa", "exa", Pos(1, 17))
+      assert(wordCompletionResponse.completions.contains(LetBindCompletion("example", "schema")))
+      val dotCompletionResponse = runDotCompletion(compilerService, joe, "SELECT * FROM example.", Pos(1, 22))
+      assert(dotCompletionResponse.completions.collect {
+        case LetBindCompletion(name, _) => name
+      }.toSet === Set("airports", "trips", "machines"))
+    } finally {
+      pool.close()
+      compilerService.stop()
+    }
+  }
+
+  test("[lsp] not enough connections in total") { _ =>
+    // A single user calls LSP, while all others are running a long query. The user can't get results
+    // because max-connections is set to others.size and all connections are taken.
+    val joe = users.head
+    val others = users.tail
+    property("raw.client.sql.pool.max-connections", s"${others.size}") // one less than needed
+    property("raw.client.sql.pool.max-connections-per-db", s"1")
+    val compilerService = new SqlCompilerService()
+    val pool = Executors.newFixedThreadPool(others.size)
+    try {
+      val futures = others.map(user => pool.submit(() => runExecute(compilerService, user, longRunningQuery, 5)))
+      val results = futures.map(_.get(60, TimeUnit.SECONDS))
+      results.foreach {
+        case ExecutionSuccess(complete) => complete shouldBe true
+        case r => fail(s"unexpected result $r")
+      }
+      // hover returns nothing
+      val hoverResponse = runHover(compilerService, joe, "SELECT * FROM example.airports", Pos(1, 17))
+      assert(hoverResponse.completion.isEmpty)
+      val hoverResponse2 =
+        runHover(compilerService, joe, "SELECT * FROM example.airports WHERE :id = airport_id", Pos(1, 40))
+      assert(hoverResponse2.completion.isEmpty)
+      // we get no word completions
+      val wordCompletionResponse = runWordCompletion(compilerService, joe, "SELECT * FROM exa", "exa", Pos(1, 17))
+      assert(wordCompletionResponse.completions.isEmpty)
+      // we get no dot completions
+      val dotCompletionResponse = runDotCompletion(compilerService, joe, "SELECT * FROM example.", Pos(1, 22))
+      assert(dotCompletionResponse.completions.isEmpty)
+    } finally {
+      pool.close()
+      compilerService.stop()
+    }
+  }
+
+  test("[lsp] enough connections per user") { _ =>
+    // Again, a single user runs LSP calls while all users (including that one) are running a long query.
+    // With two connections allowed per user, the single user manages to run all LSP calls.
+    val joe = users.head
+    property("raw.client.sql.pool.max-connections", s"${nUsers * 2}")
+    property("raw.client.sql.pool.max-connections-per-db", s"2")
+    val compilerService = new SqlCompilerService()
+    val pool = Executors.newFixedThreadPool(users.size)
+    try {
+      val futures = users.map(user => pool.submit(() => runExecute(compilerService, user, longRunningQuery, 5)))
+      Thread.sleep(2000) // give some time to make sure they're all running
+      val hoverResponse = runHover(compilerService, joe, "SELECT * FROM example.airports", Pos(1, 17))
+      assert(hoverResponse.completion.contains(TypeCompletion("example", "schema")))
+      val hoverResponse2 =
+        runHover(compilerService, joe, "SELECT * FROM example.airports WHERE :id = airport_id", Pos(1, 40))
+      assert(hoverResponse2.completion.contains(TypeCompletion("id", "integer")))
+      val wordCompletionResponse = runWordCompletion(compilerService, joe, "SELECT * FROM exa", "exa", Pos(1, 17))
+      assert(wordCompletionResponse.completions.contains(LetBindCompletion("example", "schema")))
+      val dotCompletionResponse = runDotCompletion(compilerService, joe, "SELECT * FROM example.", Pos(1, 22))
+      assert(dotCompletionResponse.completions.collect {
+        case LetBindCompletion(name, _) => name
+      }.toSet === Set("airports", "trips", "machines"))
+      val results = futures.map(_.get(60, TimeUnit.SECONDS))
+      results.foreach {
+        case ExecutionSuccess(complete) => complete shouldBe true
+        case r => fail(s"unexpected result $r")
+      }
+    } finally {
+      pool.close()
+      compilerService.stop()
+    }
+  }
+
+  test("[lsp] not enough connections per user") { _ =>
+    // All users run a long query, including the one who then issues an LSP request. The LSP calls fail because
+    // only one connection per user is allowed.
+    val joe = users.head
+    property("raw.client.sql.pool.max-connections", s"${nUsers * 2}") // plenty
+    property("raw.client.sql.pool.max-connections-per-db", s"1") // only one per user
+    val compilerService = new SqlCompilerService()
+    val pool = Executors.newFixedThreadPool(users.size)
+    try {
+      // All users run a long query
+      val futures = users.map(user => pool.submit(() => runExecute(compilerService, user, longRunningQuery, 5)))
+      Thread.sleep(2000) // give some time to make sure they're all running
+      // hover returns nothing
+      val hoverResponse = runHover(compilerService, joe, "SELECT * FROM example.airports", Pos(1, 17))
+      assert(hoverResponse.completion.isEmpty)
+      val hoverResponse2 =
+        runHover(compilerService, joe, "SELECT * FROM example.airports WHERE :id = airport_id", Pos(1, 40))
+      assert(hoverResponse2.completion.isEmpty)
+      // word and dot completion return an empty list
+      val wordCompletionResponse = runWordCompletion(compilerService, joe, "SELECT * FROM exa", "exa", Pos(1, 17))
+      assert(wordCompletionResponse.completions.isEmpty)
+      val dotCompletionResponse = runDotCompletion(compilerService, joe, "SELECT * FROM example.", Pos(1, 22))
+      assert(dotCompletionResponse.completions.isEmpty)
+      val results = futures.map(_.get(60, TimeUnit.SECONDS))
+      results.foreach {
+        case ExecutionSuccess(complete) => complete shouldBe true
+        case r => fail(s"unexpected result $r")
+      }
+    } finally {
+      pool.close()
+      compilerService.stop()
+    }
+  }
+
+  test("[execute] enough connections in total") { _ =>
+    /* Each user runs the same long query `nCalls` times, one call at a time. The same connection is reused per user.
+     * This is confirmed by setting max-connections-per-db to 1 although several calls are performed per DB.
+     * In total, there's one connection per user. Setting max-connections to nUsers is therefore sufficient.
+     */
+    val nCalls = 2
+    property("raw.client.sql.pool.max-connections", s"$nUsers")
+    property("raw.client.sql.pool.max-connections-per-db", s"1")
+    val compilerService = new SqlCompilerService()
+    val iterations = 1 to nCalls
+    try {
+      val results = users
+        .map(user => user -> iterations.map(_ => runExecute(compilerService, user, longRunningQuery, 0)))
+        .toMap
+      for (userResults <- results.values; r <- userResults) r match {
+        case ExecutionSuccess(complete) => complete shouldBe true
+        case _ => fail(s"unexpected result $r")
+      }
+    } finally {
+      compilerService.stop()
+    }
+  }
+
+  test("[execute] enough connections per user") { _ =>
+    // We run `execute` _in parallel_ using a long query. Each user runs it `nCalls` times. So we have
+    // a total number of queries of nUsers x nCalls. We set max-connections to that value to be safe and
+    // set max-connections-per-db to nCalls so that all concurrent queries can run.
+    val nCalls = 3
+    property("raw.client.sql.pool.max-connections", s"${nUsers * nCalls}") // enough total
+    property("raw.client.sql.pool.max-connections-per-db", s"$nCalls") // exactly what is needed per user
+    val compilerService = new SqlCompilerService()
+    val pool = Executors.newFixedThreadPool(nUsers * nCalls)
+    val iterations = 1 to nCalls
+    try {
+      val futures = users
+        .map(user =>
+          user -> iterations.map(_ => pool.submit(() => runExecute(compilerService, user, longRunningQuery, 5)))
+        )
+        .toMap
+      val results = futures.mapValues(_.map(_.get(60, TimeUnit.SECONDS)))
+      for (userResults <- results.values; r <- userResults) r match {
+        case ExecutionSuccess(complete) => complete shouldBe true
+        case _ => fail(s"unexpected result $r")
+      }
+    } finally {
+      pool.close()
+      compilerService.stop()
+    }
+  }
+
+  test("[execute] not enough connections") { _ =>
+    /* Each user runs execute twice, one call at a time. The same connection can be reused per user.
+     * In total, there's one connection per user. Setting max-connections to nUsers - 1 triggers the
+     * expected failure. The number of errors hit should be positive (checked at the end).
+     */
+    val nCalls = 2
+    property("raw.client.sql.pool.max-connections", s"${nUsers - 1}")
+    property("raw.client.sql.pool.max-connections-per-db", s"1")
+    val compilerService = new SqlCompilerService()
+    val iterations = 1 to nCalls
+    try {
+      val results = users
+        .map(user => user -> iterations.map(_ => runExecute(compilerService, user, longRunningQuery, 0)))
+        .toMap
+      for (userResults <- results.values; r <- userResults) r match {
+        case ExecutionSuccess(complete) => complete shouldBe true
+        case ExecutionRuntimeFailure(error) => error shouldBe "no connections available"
+        case _ => fail(s"unexpected result $r")
+      }
+      val errorCount = results.values.map(_.count(_.isInstanceOf[ExecutionRuntimeFailure])).sum
+      errorCount should be > 0
+    } finally {
+      compilerService.stop()
+    }
+  }
+
+  test("[getProgramDescription] not enough connections") { _ =>
+    /* Each user runs getProgramDescription twice, one call at a time. The same connection can be reused per user.
+     * In total, there's one connection per user. Setting max-connections to nUsers - 1 triggers the
+     * expected failure. The number of errors hit should be positive (checked at the end).
+     */
+    val nCalls = 2
+    property("raw.client.sql.pool.max-connections", s"${nUsers - 1}")
+    property("raw.client.sql.pool.max-connections-per-db", s"1")
+    val compilerService = new SqlCompilerService()
+    val iterations = 1 to nCalls
+    try {
+      val results = users
+        .map(user => user -> iterations.map(_ => runGetProgramDescription(compilerService, user, longValidateQuery)))
+        .toMap
+      for (userResults <- results.values; r <- userResults) r match {
+        case GetProgramDescriptionSuccess(_) =>
+        case GetProgramDescriptionFailure(errors) =>
+          errors.size shouldBe 1
+          errors.head.message shouldBe "no connections available"
+      }
+      val errorCount = results.values.map(_.count(_.isInstanceOf[GetProgramDescriptionFailure])).sum
+      errorCount should be > 0
+    } finally {
+      compilerService.stop()
+    }
+  }
+
+  test("[validate] not enough connections") { _ =>
+    /* Each user runs validate twice, one call at a time. The same connection can be reused per user.
+     * In total, there's one connection per user. Setting max-connections to nUsers - 1 triggers the
+     * expected failure. The number of errors hit should be positive (checked at the end).
+     */
+    val nCalls = 2
+    property("raw.client.sql.pool.max-connections", s"${nUsers - 1}")
+    property("raw.client.sql.pool.max-connections-per-db", s"1")
+    val compilerService = new SqlCompilerService()
+    val iterations = 1 to nCalls
+    try {
+      val results = users
+        .map(user => user -> iterations.map(_ => runValidate(compilerService, user, longValidateQuery)))
+        .toMap
+      for (userResults <- results.values; r <- userResults) r match {
+        case ValidateResponse(errors) if errors.isEmpty =>
+        case ValidateResponse(errors) =>
+          errors.size shouldBe 1
+          errors.head.message shouldBe "no connections available"
+      }
+      val errorCount = results.values.map(_.count(_.messages.nonEmpty)).sum
+      errorCount should be > 0
+    } finally {
+      compilerService.stop()
+    }
+  }
+
+  test("[execute] not enough connections per user") { _ =>
+    // We run `execute` in parallel using a long query. Each user runs it `nCalls` times. So we have
+    // a total number of queries of nUsers x nCalls. We set max-connections to that value to be safe but
+    // set max-connections-per-db to just one so that the concurrent queries cannot all get a connection although
+    // max-connections would allow it.
+    val nCalls = 10
+    property("raw.client.sql.pool.max-connections", s"${nUsers * nCalls}") // in principle enough
+    property("raw.client.sql.pool.max-connections-per-db", s"1") // but only one connection per user
+    val compilerService = new SqlCompilerService()
+    val pool = Executors.newFixedThreadPool(nUsers * nCalls)
+    val iterations = 1 to nCalls
+    try {
+      val futures = users
+        .map(user =>
+          user -> iterations.map(_ => pool.submit(() => runExecute(compilerService, user, longRunningQuery, 5)))
+        )
+        .toMap
+      val results = futures.mapValues(_.map(_.get(60, TimeUnit.SECONDS)))
+      for (userResults <- results.values; r <- userResults) r match {
+        case ExecutionSuccess(complete) => complete shouldBe true
+        case ExecutionRuntimeFailure(error) => error shouldBe "too many connections active"
+        case _ => fail(s"unexpected result $r")
+      }
+      val errorCount = results.values.map(_.count(_.isInstanceOf[ExecutionRuntimeFailure])).sum
+      errorCount should be > 0
+    } finally {
+      pool.close()
+      compilerService.stop()
+    }
+  }
+
+  test("[getProgramDescription] not enough connections per user") { _ =>
+    // We run `getProgramDescription` in parallel using a long query. Each user runs it `nCalls` times. So we have
+    // a total number of queries of nUsers x nCalls. We set max-connections to that value to be safe but
+    // set max-connections-per-db to just one so that the concurrent queries cannot all get a connection although
+    // max-connections would allow it.
+    val nCalls = 10
+    property("raw.client.sql.pool.max-connections", s"${nUsers * nCalls}") // in principle enough
+    property("raw.client.sql.pool.max-connections-per-db", s"1") // but only one connection per user
+    val compilerService = new SqlCompilerService()
+    val pool = Executors.newFixedThreadPool(nUsers * nCalls)
+    val iterations = 1 to nCalls
+    try {
+      val futures = users
+        .map(user =>
+          user -> iterations.map(_ =>
+            pool.submit(() => runGetProgramDescription(compilerService, user, longValidateQuery))
+          )
+        )
+        .toMap
+      val results = futures.mapValues(_.map(_.get(60, TimeUnit.SECONDS)))
+      for (userResults <- results.values; r <- userResults) r match {
+        case GetProgramDescriptionSuccess(_) =>
+        case GetProgramDescriptionFailure(errors) =>
+          errors.size shouldBe 1
+          errors.head.message shouldBe "too many connections active"
+      }
+      val errorCount = results.values.map(_.count(_.isInstanceOf[GetProgramDescriptionFailure])).sum
+      errorCount should be > 0
+    } finally {
+      pool.close()
+      compilerService.stop()
+    }
+  }
+
+  test("[validate] not enough connections per user") { _ =>
+    // We run `validate` in parallel using a long query. Each user runs it `nCalls` times. So we have
+    // a total number of queries of nUsers x nCalls. We set max-connections to that value to be safe but
+    // set max-connections-per-db to just one so that the concurrent queries cannot all get a connection although
+    // max-connections would allow it.
+    val nCalls = 10
+    property("raw.client.sql.pool.max-connections", s"${nUsers * nCalls}") // in principle enough
+    property("raw.client.sql.pool.max-connections-per-db", s"1") // but only one connection per user
+    val compilerService = new SqlCompilerService()
+    val pool = Executors.newFixedThreadPool(nUsers * nCalls)
+    val iterations = 1 to nCalls
+    try {
+      val futures = users
+        .map(user =>
+          user -> iterations.map(_ => pool.submit(() => runValidate(compilerService, user, longValidateQuery)))
+        )
+        .toMap
+      val results = futures.mapValues(_.map(_.get(60, TimeUnit.SECONDS)))
+      for (userResults <- results.values; r <- userResults) r match {
+        case ValidateResponse(errors) if errors.isEmpty =>
+        case ValidateResponse(errors) =>
+          errors.size shouldBe 1
+          errors.head.message shouldBe "too many connections active"
+      }
+      val errorCount = results.values.map(_.count(_.messages.nonEmpty)).sum
+      errorCount should be > 0
+    } finally {
+      pool.close()
+      compilerService.stop()
+    }
+  }
+
+  private def runExecute(
+      compilerService: CompilerService,
+      user: AuthenticatedUser,
+      code: String,
+      arg: Int
+  ): ExecutionResponse = {
+    val env = ProgramEnvironment(
+      user,
+      Some(Array("arg" -> RawInt(arg))),
+      Set.empty,
+      Map("output-format" -> "json"),
+      jdbcUrl = Some(jdbcUrl(user))
+    )
+    val baos = new ByteArrayOutputStream()
+    try {
+      compilerService.execute(code, env, None, baos)
+    } finally {
+      baos.close()
+    }
+  }
+
+  private def runHover(
+      compilerService: CompilerService,
+      user: AuthenticatedUser,
+      code: String,
+      pos: Pos
+  ): HoverResponse = {
+    val env = ProgramEnvironment(
+      user,
+      None,
+      Set.empty,
+      Map("output-format" -> "json"),
+      jdbcUrl = Some(jdbcUrl(user))
+    )
+    compilerService.hover(code, env, pos)
+  }
+
+  private def runWordCompletion(
+      compilerService: CompilerService,
+      user: AuthenticatedUser,
+      code: String,
+      prefix: String,
+      pos: Pos
+  ): AutoCompleteResponse = {
+    val env = ProgramEnvironment(
+      user,
+      None,
+      Set.empty,
+      Map("output-format" -> "json"),
+      jdbcUrl = Some(jdbcUrl(user))
+    )
+    compilerService.wordAutoComplete(code, env, prefix, pos)
+  }
+
+  private def runDotCompletion(
+      compilerService: CompilerService,
+      user: AuthenticatedUser,
+      code: String,
+      pos: Pos
+  ): AutoCompleteResponse = {
+    val env = ProgramEnvironment(
+      user,
+      None,
+      Set.empty,
+      Map("output-format" -> "json"),
+      jdbcUrl = Some(jdbcUrl(user))
+    )
+    compilerService.dotAutoComplete(code, env, pos)
+  }
+
+  private def runGetProgramDescription(
+      compilerService: CompilerService,
+      user: AuthenticatedUser,
+      code: String
+  ): GetProgramDescriptionResponse = {
+    val env = ProgramEnvironment(
+      user,
+      None,
+      Set.empty,
+      Map("output-format" -> "json"),
+      jdbcUrl = Some(jdbcUrl(user))
+    )
+    compilerService.getProgramDescription(code, env)
+  }
+
+  private def runValidate(
+      compilerService: CompilerService,
+      user: AuthenticatedUser,
+      code: String
+  ): ValidateResponse = {
+    val env = ProgramEnvironment(
+      user,
+      None,
+      Set.empty,
+      Map("output-format" -> "json"),
+      jdbcUrl = Some(jdbcUrl(user))
+    )
+    compilerService.validate(code, env)
+  }
+
+  // it sleeps for 'arg' seconds (default 1)
+  private val longRunningQuery = """-- @default arg 1
+    |SELECT CAST(pg_sleep(:arg) AS VARCHAR)""".stripMargin
+
+  // it runs fast, but its default parameter value takes long to compute. This allows us to
+  // slow down validation calls.
+  private val longValidateQuery = """-- @default arg SELECT 1 WHERE pg_sleep(5) IS NOT NULL
+    |SELECT :arg""".stripMargin
+
+  override def afterAll(): Unit = {
+    container.stop()
+    super.afterAll()
+  }
+
+}
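
For readers unfamiliar with SQLSTATE classes, the convention this patch relies on can be sketched in isolation. The snippet below is a minimal, self-contained illustration and not code from the patch: `ConnectionFailureSketch`, `borrow` and `completions` are hypothetical stand-ins, while the "08000" state and the class-08 check mirror what SqlConnectionPool and the `isConnectionFailure` helpers do.

import java.sql.SQLException

object ConnectionFailureSketch extends App {

  // Stands in for SqlConnectionPool.getConnection when the pool is exhausted:
  // the pool now raises SQLSTATE 08000 ("connection exception") instead of a bare SQLException.
  def borrow(available: Int): String =
    if (available <= 0) throw new SQLException("no connections available", "08000")
    else "connection"

  // Same check as the isConnectionFailure helpers added in this patch:
  // any SQLSTATE in class 08 counts as a pool/connection failure.
  def isConnectionFailure(ex: SQLException): Boolean = {
    val state = ex.getSQLState
    state != null && state.startsWith("08")
  }

  // Callers degrade to an "empty" result (empty completions, a failure response carrying
  // the message, etc.) instead of letting the SQLException escape.
  def completions(available: Int): Seq[String] =
    try Seq(s"ok via ${borrow(available)}")
    catch { case ex: SQLException if isConnectionFailure(ex) => Seq.empty }

  println(completions(available = 1)) // List(ok via connection)
  println(completions(available = 0)) // List() -- pool exhausted, degraded gracefully
}

Keying the check on the SQLSTATE class rather than on exception messages lets every caller touched here (execute, validate, getProgramDescription, hover, word and dot completion) distinguish pool exhaustion from genuine query errors with a single predicate.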