Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SQL client uses test containers + JDBC URL added to ProgramEnvironment #452

Merged
merged 4 commits into from
Jul 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -272,9 +272,7 @@ lazy val sqlParser = (project in file("sql-parser"))
lazy val sqlClient = (project in file("sql-client"))
.dependsOn(
client % "compile->compile;test->test",
snapiFrontend % "compile->compile;test->test",
sqlParser % "compile->compile;test->test",
sources % "compile->compile;test->test",
)
.settings(
commonSettings,
Expand All @@ -283,7 +281,9 @@ lazy val sqlClient = (project in file("sql-client"))
libraryDependencies ++= Seq(
kiama,
postgresqlDeps,
hikariCP
hikariCP,
"com.dimafeng" %% "testcontainers-scala-scalatest" % "0.41.3" % Test,
"com.dimafeng" %% "testcontainers-scala-postgresql" % "0.41.3" % Test
)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,6 @@ final case class ProgramEnvironment(
maybeArguments: Option[Array[(String, RawValue)]],
scopes: Set[String],
options: Map[String, String],
maybeTraceId: Option[String] = None
maybeTraceId: Option[String] = None,
jdbcUrl: Option[String] = None
)
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ public RawContext(RawLanguage language, Env env) {
// Arguments are unused by the runtime in case of Truffle.
Option<Tuple2<String, RawValue>[]> maybeArguments = Option.empty();
this.programEnvironment =
new ProgramEnvironment(this.user, maybeArguments, scalaScopes, scalaOptions, maybeTraceId);
new ProgramEnvironment(
this.user, maybeArguments, scalaScopes, scalaOptions, maybeTraceId, Option.empty());

// The function registry holds snapi methods (top level functions). It is the data
// structure that is used to extract a ref to a function from a piece of execute snapi.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

package raw.creds.client

import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}
import raw.creds.api._
import raw.creds.client.ClientCredentialsService.SERVER_ADDRESS
import raw.rest.client.APIException
Expand Down
2 changes: 0 additions & 2 deletions sql-client/src/main/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
requires raw.sql.parser;
requires java.sql;
requires com.zaxxer.hikari;
requires raw.snapi.frontend;
requires raw.sources;

provides raw.client.api.CompilerServiceBuilder with
raw.client.sql.SqlCompilerServiceBuilder;
Expand Down
40 changes: 24 additions & 16 deletions sql-client/src/main/scala/raw/client/sql/SqlCompilerService.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,32 @@ import raw.client.api._
import raw.client.sql.antlr4.{ParseProgramResult, RawSqlSyntaxAnalyzer, SqlIdnNode, SqlParamUseNode}
import raw.client.sql.metadata.UserMetadataCache
import raw.client.sql.writers.{TypedResultSetCsvWriter, TypedResultSetJsonWriter}
import raw.creds.api.CredentialsServiceProvider
import raw.utils.{AuthenticatedUser, RawSettings, RawUtils}
import raw.utils.{RawSettings, RawUtils}

import java.io.{IOException, OutputStream}
import java.sql.ResultSet
import scala.util.control.NonFatal

/**
* A CompilerService implementation for the SQL (Postgres) language.
*
* @param settings The configuration settings for the SQL compiler.
*/
class SqlCompilerService()(implicit protected val settings: RawSettings) extends CompilerService {

private val credentials = CredentialsServiceProvider()

private val connectionPool = new SqlConnectionPool(credentials)
private val connectionPool = new SqlConnectionPool()

// A short lived database metadata (schema/table/column names) indexed by JDBC URL.
private val metadataBrowsers = {
val loader = new CacheLoader[AuthenticatedUser, UserMetadataCache] {
override def load(user: AuthenticatedUser): UserMetadataCache =
new UserMetadataCache(user, connectionPool, settings)
val maxSize = settings.getInt("raw.client.sql.metadata-cache.max-matches")
val expiry = settings.getDuration("raw.client.sql.metadata-cache.match-validity")
val loader = new CacheLoader[String, UserMetadataCache] {
override def load(jdbcUrl: String): UserMetadataCache = new UserMetadataCache(
jdbcUrl,
connectionPool,
maxSize = maxSize,
expiry = expiry
)
}
CacheBuilder
.newBuilder()
Expand Down Expand Up @@ -77,7 +86,7 @@ class SqlCompilerService()(implicit protected val settings: RawSettings) extends
safeParse(source) match {
case Left(errors) => GetProgramDescriptionFailure(errors)
case Right(parsedTree) =>
val conn = connectionPool.getConnection(environment.user)
val conn = connectionPool.getConnection(environment.jdbcUrl.get)
try {
val stmt = new NamedParametersPreparedStatement(conn, parsedTree)
val description = stmt.queryMetadata match {
Expand Down Expand Up @@ -148,7 +157,7 @@ class SqlCompilerService()(implicit protected val settings: RawSettings) extends
safeParse(source) match {
case Left(errors) => ExecutionValidationFailure(errors)
case Right(parsedTree) =>
val conn = connectionPool.getConnection(environment.user)
val conn = connectionPool.getConnection(environment.jdbcUrl.get)
try {
val pstmt = new NamedParametersPreparedStatement(conn, parsedTree, environment.scopes)
try {
Expand Down Expand Up @@ -245,7 +254,7 @@ class SqlCompilerService()(implicit protected val settings: RawSettings) extends
// So we call the identifier with +1 column
analyzer.identifierUnder(Pos(position.line, position.column + 1)) match {
case Some(idn: SqlIdnNode) =>
val metadataBrowser = metadataBrowsers.get(environment.user)
val metadataBrowser = metadataBrowsers.get(environment.jdbcUrl.get)
val matches = metadataBrowser.getDotCompletionMatches(idn)
val collectedValues = matches.collect {
case (idns, tipe) =>
Expand Down Expand Up @@ -278,7 +287,7 @@ class SqlCompilerService()(implicit protected val settings: RawSettings) extends
logger.debug(s"idn $item")
val matches: Seq[Completion] = item match {
case Some(idn: SqlIdnNode) =>
val metadataBrowser = metadataBrowsers.get(environment.user)
val metadataBrowser = metadataBrowsers.get(environment.jdbcUrl.get)
val matches = metadataBrowser.getWordCompletionMatches(idn)
matches.collect { case (idns, value) => LetBindCompletion(idns.last.value, value) }
case Some(use: SqlParamUseNode) => tree.params.collect {
Expand All @@ -302,13 +311,13 @@ class SqlCompilerService()(implicit protected val settings: RawSettings) extends
.identifierUnder(position)
.map {
case identifier: SqlIdnNode =>
val metadataBrowser = metadataBrowsers.get(environment.user)
val metadataBrowser = metadataBrowsers.get(environment.jdbcUrl.get)
val matches = metadataBrowser.getWordCompletionMatches(identifier)
matches.headOption
.map { case (names, tipe) => HoverResponse(Some(TypeCompletion(formatIdns(names), tipe))) }
.getOrElse(HoverResponse(None))
case use: SqlParamUseNode =>
val conn = connectionPool.getConnection(environment.user)
val conn = connectionPool.getConnection(environment.jdbcUrl.get)
try {
val pstmt = new NamedParametersPreparedStatement(conn, tree)
try {
Expand Down Expand Up @@ -361,7 +370,7 @@ class SqlCompilerService()(implicit protected val settings: RawSettings) extends
safeParse(source) match {
case Left(errors) => ValidateResponse(errors)
case Right(parsedTree) =>
val conn = connectionPool.getConnection(environment.user)
val conn = connectionPool.getConnection(environment.jdbcUrl.get)
try {
val stmt = new NamedParametersPreparedStatement(conn, parsedTree)
try {
Expand Down Expand Up @@ -395,7 +404,6 @@ class SqlCompilerService()(implicit protected val settings: RawSettings) extends

override def doStop(): Unit = {
connectionPool.stop()
credentials.stop()
}

private def pgRowTypeToIterableType(rowType: PostgresRowType): Either[Seq[String], RawIterableType] = {
Expand Down
33 changes: 8 additions & 25 deletions sql-client/src/main/scala/raw/client/sql/SqlConnectionPool.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,35 +14,28 @@ package raw.client.sql
import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache, RemovalNotification}
import com.typesafe.scalalogging.StrictLogging
import com.zaxxer.hikari.{HikariConfig, HikariDataSource}
import raw.creds.api.CredentialsService
import raw.utils.{AuthenticatedUser, RawService, RawSettings, RawUtils}
import raw.utils.{RawService, RawSettings, RawUtils}

import java.sql.SQLException
import java.util.concurrent.{ConcurrentHashMap, Executors, TimeUnit}
import scala.collection.mutable

class SqlConnectionPool(credentialsService: CredentialsService)(implicit settings: RawSettings)
extends RawService
with StrictLogging {
class SqlConnectionPool()(implicit settings: RawSettings) extends RawService with StrictLogging {

private val dbHost = settings.getString("raw.creds.jdbc.fdw.host")
private val dbPort = settings.getInt("raw.creds.jdbc.fdw.port")
private val readOnlyUser = settings.getString("raw.creds.jdbc.fdw.user")
private val password = settings.getString("raw.creds.jdbc.fdw.password")
private val maxConnections = settings.getInt("raw.client.sql.pool.max-connections")
private val idleTimeout = settings.getDuration("raw.client.sql.pool.idle-timeout", TimeUnit.MILLISECONDS)
private val maxLifetime = settings.getDuration("raw.client.sql.pool.max-lifetime", TimeUnit.MILLISECONDS)
private val connectionTimeout = settings.getDuration("raw.client.sql.pool.connection-timeout", TimeUnit.MILLISECONDS)

private val poolGarbageCollectionPeriod = settings.getDuration("raw.client.sql.pool.gc-period")
private val poolsToDelete = new ConcurrentHashMap[String, HikariDataSource]()
private val garbageCollectScheduller =
private val garbageCollectScheduler =
Executors.newSingleThreadScheduledExecutor(RawUtils.newThreadFactory("sql-connection-pool-gc"))

// Periodically check for idle pools and close them
// If the hikari pool in the cache expires and still has active connections, we will move it to the poolsToDelete map
// Then we delete it later when the active connections are 0 (i.e. long queries are done and the pool is not needed anymore)
garbageCollectScheduller.scheduleAtFixedRate(
garbageCollectScheduler.scheduleAtFixedRate(
() => {
val urlsToRemove = mutable.ArrayBuffer[String]()
poolsToDelete.forEach((url, pool) => {
Expand Down Expand Up @@ -72,8 +65,6 @@ class SqlConnectionPool(credentialsService: CredentialsService)(implicit setting
config.setIdleTimeout(idleTimeout)
config.setMaxLifetime(maxLifetime)
config.setConnectionTimeout(connectionTimeout)
config.setUsername(readOnlyUser)
config.setPassword(password)
val pool = new HikariDataSource(config)
pool
}
Expand All @@ -95,24 +86,16 @@ class SqlConnectionPool(credentialsService: CredentialsService)(implicit setting
.build(dbCacheLoader)

@throws[SQLException]
def getConnection(user: AuthenticatedUser): java.sql.Connection = {
val db = settings.getStringOpt(s"raw.creds.jdbc.${user.uid.uid}.db").getOrElse(credentialsService.getUserDb(user))
val maybeSchema = settings.getStringOpt(s"raw.creds.jdbc.${user.uid.uid}.schema")

val url = maybeSchema match {
case Some(schema) => s"jdbc:postgresql://$dbHost:$dbPort/$db?currentSchema=$schema"
case None => s"jdbc:postgresql://$dbHost:$dbPort/$db"
}

dbCache.get(url).getConnection()
def getConnection(jdbcUrl: String): java.sql.Connection = {
dbCache.get(jdbcUrl).getConnection()
}

override def doStop(): Unit = {
dbCache.asMap().values().forEach { pool =>
logger.info(s"Shutting down SQL connection pool for database ${pool.getJdbcUrl}")
RawUtils.withSuppressNonFatalException(pool.close())
}
garbageCollectScheduller.shutdown()
garbageCollectScheduller.awaitTermination(5, TimeUnit.SECONDS)
garbageCollectScheduler.shutdown()
garbageCollectScheduler.awaitTermination(5, TimeUnit.SECONDS)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ package raw.client.sql.metadata

import com.google.common.cache.{CacheBuilder, CacheLoader}
import com.typesafe.scalalogging.StrictLogging
import java.time.Duration
import raw.client.sql.antlr4.{SqlIdentifierNode, SqlIdnNode, SqlProjNode}
import raw.client.sql.{SqlConnectionPool, SqlIdentifier}
import raw.utils.{AuthenticatedUser, RawSettings}

case class IdentifierInfo(name: Seq[SqlIdentifier], tipe: String)

Expand All @@ -28,13 +28,13 @@ case class IdentifierInfo(name: Seq[SqlIdentifier], tipe: String)
* Entries in these two caches are fairly short-lived. They get deleted (and recomputed if needed) after a few seconds
* so that the user can see new schemas, tables or columns that have been created in the database.
*/
class UserMetadataCache(user: AuthenticatedUser, connectionPool: SqlConnectionPool, settings: RawSettings)
class UserMetadataCache(jdbcUrl: String, connectionPool: SqlConnectionPool, maxSize: Int, expiry: Duration)
extends StrictLogging {

private val wordCompletionCache = {
val loader = new CacheLoader[Seq[SqlIdentifier], Seq[IdentifierInfo]]() {
override def load(idns: Seq[SqlIdentifier]): Seq[IdentifierInfo] = {
val con = connectionPool.getConnection(user)
val con = connectionPool.getConnection(jdbcUrl)
try {
val query = idns.size match {
case 3 => WordSearchWithThreeItems
Expand All @@ -50,8 +50,8 @@ class UserMetadataCache(user: AuthenticatedUser, connectionPool: SqlConnectionPo
}
CacheBuilder
.newBuilder()
.maximumSize(settings.getInt("raw.client.sql.metadata-cache.max-matches"))
.expireAfterWrite(settings.getDuration("raw.client.sql.metadata-cache.match-validity"))
.maximumSize(maxSize)
.expireAfterWrite(expiry)
.build(loader)
}

Expand Down Expand Up @@ -100,7 +100,7 @@ class UserMetadataCache(user: AuthenticatedUser, connectionPool: SqlConnectionPo
private val dotCompletionCache = {
val loader = new CacheLoader[Seq[SqlIdentifier], Seq[IdentifierInfo]]() {
override def load(idns: Seq[SqlIdentifier]): Seq[IdentifierInfo] = {
val con = connectionPool.getConnection(user)
val con = connectionPool.getConnection(jdbcUrl)
try {
val query = idns.size match {
case 2 => DotSearchWithTwoItems
Expand All @@ -115,8 +115,8 @@ class UserMetadataCache(user: AuthenticatedUser, connectionPool: SqlConnectionPo
}
CacheBuilder
.newBuilder()
.maximumSize(settings.getInt("raw.client.sql.metadata-cache.max-matches"))
.expireAfterWrite(settings.getDuration("raw.client.sql.metadata-cache.match-validity"))
.maximumSize(maxSize)
.expireAfterWrite(expiry)
.build(loader)
}

Expand Down
Loading