Skip to content

Commit

Permalink
SQL client uses test containers + JDBC URL added to ProgramEnvironment (
Browse files Browse the repository at this point in the history
  • Loading branch information
bgaidioz authored Jul 1, 2024
1 parent 2ae3f90 commit 3ccc2b4
Show file tree
Hide file tree
Showing 11 changed files with 8,406 additions and 283 deletions.
6 changes: 3 additions & 3 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -272,9 +272,7 @@ lazy val sqlParser = (project in file("sql-parser"))
lazy val sqlClient = (project in file("sql-client"))
.dependsOn(
client % "compile->compile;test->test",
snapiFrontend % "compile->compile;test->test",
sqlParser % "compile->compile;test->test",
sources % "compile->compile;test->test",
)
.settings(
commonSettings,
Expand All @@ -283,7 +281,9 @@ lazy val sqlClient = (project in file("sql-client"))
libraryDependencies ++= Seq(
kiama,
postgresqlDeps,
hikariCP
hikariCP,
"com.dimafeng" %% "testcontainers-scala-scalatest" % "0.41.3" % Test,
"com.dimafeng" %% "testcontainers-scala-postgresql" % "0.41.3" % Test
)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,6 @@ final case class ProgramEnvironment(
maybeArguments: Option[Array[(String, RawValue)]],
scopes: Set[String],
options: Map[String, String],
maybeTraceId: Option[String] = None
maybeTraceId: Option[String] = None,
jdbcUrl: Option[String] = None
)
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ public RawContext(RawLanguage language, Env env) {
// Arguments are unused by the runtime in case of Truffle.
Option<Tuple2<String, RawValue>[]> maybeArguments = Option.empty();
this.programEnvironment =
new ProgramEnvironment(this.user, maybeArguments, scalaScopes, scalaOptions, maybeTraceId);
new ProgramEnvironment(
this.user, maybeArguments, scalaScopes, scalaOptions, maybeTraceId, Option.empty());

// The function registry holds snapi methods (top level functions). It is the data
// structure that is used to extract a ref to a function from a piece of execute snapi.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

package raw.creds.client

import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}
import raw.creds.api._
import raw.creds.client.ClientCredentialsService.SERVER_ADDRESS
import raw.rest.client.APIException
Expand Down
2 changes: 0 additions & 2 deletions sql-client/src/main/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
requires raw.sql.parser;
requires java.sql;
requires com.zaxxer.hikari;
requires raw.snapi.frontend;
requires raw.sources;

provides raw.client.api.CompilerServiceBuilder with
raw.client.sql.SqlCompilerServiceBuilder;
Expand Down
40 changes: 24 additions & 16 deletions sql-client/src/main/scala/raw/client/sql/SqlCompilerService.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,32 @@ import raw.client.api._
import raw.client.sql.antlr4.{ParseProgramResult, RawSqlSyntaxAnalyzer, SqlIdnNode, SqlParamUseNode}
import raw.client.sql.metadata.UserMetadataCache
import raw.client.sql.writers.{TypedResultSetCsvWriter, TypedResultSetJsonWriter}
import raw.creds.api.CredentialsServiceProvider
import raw.utils.{AuthenticatedUser, RawSettings, RawUtils}
import raw.utils.{RawSettings, RawUtils}

import java.io.{IOException, OutputStream}
import java.sql.ResultSet
import scala.util.control.NonFatal

/**
* A CompilerService implementation for the SQL (Postgres) language.
*
* @param settings The configuration settings for the SQL compiler.
*/
class SqlCompilerService()(implicit protected val settings: RawSettings) extends CompilerService {

private val credentials = CredentialsServiceProvider()

private val connectionPool = new SqlConnectionPool(credentials)
private val connectionPool = new SqlConnectionPool()

// A short lived database metadata (schema/table/column names) indexed by JDBC URL.
private val metadataBrowsers = {
val loader = new CacheLoader[AuthenticatedUser, UserMetadataCache] {
override def load(user: AuthenticatedUser): UserMetadataCache =
new UserMetadataCache(user, connectionPool, settings)
val maxSize = settings.getInt("raw.client.sql.metadata-cache.max-matches")
val expiry = settings.getDuration("raw.client.sql.metadata-cache.match-validity")
val loader = new CacheLoader[String, UserMetadataCache] {
override def load(jdbcUrl: String): UserMetadataCache = new UserMetadataCache(
jdbcUrl,
connectionPool,
maxSize = maxSize,
expiry = expiry
)
}
CacheBuilder
.newBuilder()
Expand Down Expand Up @@ -77,7 +86,7 @@ class SqlCompilerService()(implicit protected val settings: RawSettings) extends
safeParse(source) match {
case Left(errors) => GetProgramDescriptionFailure(errors)
case Right(parsedTree) =>
val conn = connectionPool.getConnection(environment.user)
val conn = connectionPool.getConnection(environment.jdbcUrl.get)
try {
val stmt = new NamedParametersPreparedStatement(conn, parsedTree)
val description = stmt.queryMetadata match {
Expand Down Expand Up @@ -148,7 +157,7 @@ class SqlCompilerService()(implicit protected val settings: RawSettings) extends
safeParse(source) match {
case Left(errors) => ExecutionValidationFailure(errors)
case Right(parsedTree) =>
val conn = connectionPool.getConnection(environment.user)
val conn = connectionPool.getConnection(environment.jdbcUrl.get)
try {
val pstmt = new NamedParametersPreparedStatement(conn, parsedTree, environment.scopes)
try {
Expand Down Expand Up @@ -245,7 +254,7 @@ class SqlCompilerService()(implicit protected val settings: RawSettings) extends
// So we call the identifier with +1 column
analyzer.identifierUnder(Pos(position.line, position.column + 1)) match {
case Some(idn: SqlIdnNode) =>
val metadataBrowser = metadataBrowsers.get(environment.user)
val metadataBrowser = metadataBrowsers.get(environment.jdbcUrl.get)
val matches = metadataBrowser.getDotCompletionMatches(idn)
val collectedValues = matches.collect {
case (idns, tipe) =>
Expand Down Expand Up @@ -278,7 +287,7 @@ class SqlCompilerService()(implicit protected val settings: RawSettings) extends
logger.debug(s"idn $item")
val matches: Seq[Completion] = item match {
case Some(idn: SqlIdnNode) =>
val metadataBrowser = metadataBrowsers.get(environment.user)
val metadataBrowser = metadataBrowsers.get(environment.jdbcUrl.get)
val matches = metadataBrowser.getWordCompletionMatches(idn)
matches.collect { case (idns, value) => LetBindCompletion(idns.last.value, value) }
case Some(use: SqlParamUseNode) => tree.params.collect {
Expand All @@ -302,13 +311,13 @@ class SqlCompilerService()(implicit protected val settings: RawSettings) extends
.identifierUnder(position)
.map {
case identifier: SqlIdnNode =>
val metadataBrowser = metadataBrowsers.get(environment.user)
val metadataBrowser = metadataBrowsers.get(environment.jdbcUrl.get)
val matches = metadataBrowser.getWordCompletionMatches(identifier)
matches.headOption
.map { case (names, tipe) => HoverResponse(Some(TypeCompletion(formatIdns(names), tipe))) }
.getOrElse(HoverResponse(None))
case use: SqlParamUseNode =>
val conn = connectionPool.getConnection(environment.user)
val conn = connectionPool.getConnection(environment.jdbcUrl.get)
try {
val pstmt = new NamedParametersPreparedStatement(conn, tree)
try {
Expand Down Expand Up @@ -361,7 +370,7 @@ class SqlCompilerService()(implicit protected val settings: RawSettings) extends
safeParse(source) match {
case Left(errors) => ValidateResponse(errors)
case Right(parsedTree) =>
val conn = connectionPool.getConnection(environment.user)
val conn = connectionPool.getConnection(environment.jdbcUrl.get)
try {
val stmt = new NamedParametersPreparedStatement(conn, parsedTree)
try {
Expand Down Expand Up @@ -395,7 +404,6 @@ class SqlCompilerService()(implicit protected val settings: RawSettings) extends

override def doStop(): Unit = {
connectionPool.stop()
credentials.stop()
}

private def pgRowTypeToIterableType(rowType: PostgresRowType): Either[Seq[String], RawIterableType] = {
Expand Down
33 changes: 8 additions & 25 deletions sql-client/src/main/scala/raw/client/sql/SqlConnectionPool.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,35 +14,28 @@ package raw.client.sql
import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache, RemovalNotification}
import com.typesafe.scalalogging.StrictLogging
import com.zaxxer.hikari.{HikariConfig, HikariDataSource}
import raw.creds.api.CredentialsService
import raw.utils.{AuthenticatedUser, RawService, RawSettings, RawUtils}
import raw.utils.{RawService, RawSettings, RawUtils}

import java.sql.SQLException
import java.util.concurrent.{ConcurrentHashMap, Executors, TimeUnit}
import scala.collection.mutable

class SqlConnectionPool(credentialsService: CredentialsService)(implicit settings: RawSettings)
extends RawService
with StrictLogging {
class SqlConnectionPool()(implicit settings: RawSettings) extends RawService with StrictLogging {

private val dbHost = settings.getString("raw.creds.jdbc.fdw.host")
private val dbPort = settings.getInt("raw.creds.jdbc.fdw.port")
private val readOnlyUser = settings.getString("raw.creds.jdbc.fdw.user")
private val password = settings.getString("raw.creds.jdbc.fdw.password")
private val maxConnections = settings.getInt("raw.client.sql.pool.max-connections")
private val idleTimeout = settings.getDuration("raw.client.sql.pool.idle-timeout", TimeUnit.MILLISECONDS)
private val maxLifetime = settings.getDuration("raw.client.sql.pool.max-lifetime", TimeUnit.MILLISECONDS)
private val connectionTimeout = settings.getDuration("raw.client.sql.pool.connection-timeout", TimeUnit.MILLISECONDS)

private val poolGarbageCollectionPeriod = settings.getDuration("raw.client.sql.pool.gc-period")
private val poolsToDelete = new ConcurrentHashMap[String, HikariDataSource]()
private val garbageCollectScheduller =
private val garbageCollectScheduler =
Executors.newSingleThreadScheduledExecutor(RawUtils.newThreadFactory("sql-connection-pool-gc"))

// Periodically check for idle pools and close them
// If the hikari pool in the cache expires and still has active connections, we will move it to the poolsToDelete map
// Then we delete it later when the active connections are 0 (i.e. long queries are done and the pool is not needed anymore)
garbageCollectScheduller.scheduleAtFixedRate(
garbageCollectScheduler.scheduleAtFixedRate(
() => {
val urlsToRemove = mutable.ArrayBuffer[String]()
poolsToDelete.forEach((url, pool) => {
Expand Down Expand Up @@ -72,8 +65,6 @@ class SqlConnectionPool(credentialsService: CredentialsService)(implicit setting
config.setIdleTimeout(idleTimeout)
config.setMaxLifetime(maxLifetime)
config.setConnectionTimeout(connectionTimeout)
config.setUsername(readOnlyUser)
config.setPassword(password)
val pool = new HikariDataSource(config)
pool
}
Expand All @@ -95,24 +86,16 @@ class SqlConnectionPool(credentialsService: CredentialsService)(implicit setting
.build(dbCacheLoader)

@throws[SQLException]
def getConnection(user: AuthenticatedUser): java.sql.Connection = {
val db = settings.getStringOpt(s"raw.creds.jdbc.${user.uid.uid}.db").getOrElse(credentialsService.getUserDb(user))
val maybeSchema = settings.getStringOpt(s"raw.creds.jdbc.${user.uid.uid}.schema")

val url = maybeSchema match {
case Some(schema) => s"jdbc:postgresql://$dbHost:$dbPort/$db?currentSchema=$schema"
case None => s"jdbc:postgresql://$dbHost:$dbPort/$db"
}

dbCache.get(url).getConnection()
def getConnection(jdbcUrl: String): java.sql.Connection = {
dbCache.get(jdbcUrl).getConnection()
}

override def doStop(): Unit = {
dbCache.asMap().values().forEach { pool =>
logger.info(s"Shutting down SQL connection pool for database ${pool.getJdbcUrl}")
RawUtils.withSuppressNonFatalException(pool.close())
}
garbageCollectScheduller.shutdown()
garbageCollectScheduller.awaitTermination(5, TimeUnit.SECONDS)
garbageCollectScheduler.shutdown()
garbageCollectScheduler.awaitTermination(5, TimeUnit.SECONDS)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ package raw.client.sql.metadata

import com.google.common.cache.{CacheBuilder, CacheLoader}
import com.typesafe.scalalogging.StrictLogging
import java.time.Duration
import raw.client.sql.antlr4.{SqlIdentifierNode, SqlIdnNode, SqlProjNode}
import raw.client.sql.{SqlConnectionPool, SqlIdentifier}
import raw.utils.{AuthenticatedUser, RawSettings}

case class IdentifierInfo(name: Seq[SqlIdentifier], tipe: String)

Expand All @@ -28,13 +28,13 @@ case class IdentifierInfo(name: Seq[SqlIdentifier], tipe: String)
* Entries in these two caches are fairly short-lived. They get deleted (and recomputed if needed) after a few seconds
* so that the user can see new schemas, tables or columns that have been created in the database.
*/
class UserMetadataCache(user: AuthenticatedUser, connectionPool: SqlConnectionPool, settings: RawSettings)
class UserMetadataCache(jdbcUrl: String, connectionPool: SqlConnectionPool, maxSize: Int, expiry: Duration)
extends StrictLogging {

private val wordCompletionCache = {
val loader = new CacheLoader[Seq[SqlIdentifier], Seq[IdentifierInfo]]() {
override def load(idns: Seq[SqlIdentifier]): Seq[IdentifierInfo] = {
val con = connectionPool.getConnection(user)
val con = connectionPool.getConnection(jdbcUrl)
try {
val query = idns.size match {
case 3 => WordSearchWithThreeItems
Expand All @@ -50,8 +50,8 @@ class UserMetadataCache(user: AuthenticatedUser, connectionPool: SqlConnectionPo
}
CacheBuilder
.newBuilder()
.maximumSize(settings.getInt("raw.client.sql.metadata-cache.max-matches"))
.expireAfterWrite(settings.getDuration("raw.client.sql.metadata-cache.match-validity"))
.maximumSize(maxSize)
.expireAfterWrite(expiry)
.build(loader)
}

Expand Down Expand Up @@ -100,7 +100,7 @@ class UserMetadataCache(user: AuthenticatedUser, connectionPool: SqlConnectionPo
private val dotCompletionCache = {
val loader = new CacheLoader[Seq[SqlIdentifier], Seq[IdentifierInfo]]() {
override def load(idns: Seq[SqlIdentifier]): Seq[IdentifierInfo] = {
val con = connectionPool.getConnection(user)
val con = connectionPool.getConnection(jdbcUrl)
try {
val query = idns.size match {
case 2 => DotSearchWithTwoItems
Expand All @@ -115,8 +115,8 @@ class UserMetadataCache(user: AuthenticatedUser, connectionPool: SqlConnectionPo
}
CacheBuilder
.newBuilder()
.maximumSize(settings.getInt("raw.client.sql.metadata-cache.max-matches"))
.expireAfterWrite(settings.getDuration("raw.client.sql.metadata-cache.match-validity"))
.maximumSize(maxSize)
.expireAfterWrite(expiry)
.build(loader)
}

Expand Down
Loading

0 comments on commit 3ccc2b4

Please sign in to comment.