Skip to content

Commit

Permalink
Patch to RD-10677 (#389)
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelbranco80 authored Mar 25, 2024
1 parent 0b5722a commit 65c951b
Show file tree
Hide file tree
Showing 9 changed files with 73 additions and 79 deletions.
8 changes: 4 additions & 4 deletions hard-rebuild.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,6 @@ cd ../client
rm -rf target/
sbt clean publishLocal

cd ../sql-client
rm -rf target/
sbt clean publishLocal

cd ../snapi-parser
rm -rf target/
sbt clean publishLocal
Expand All @@ -30,6 +26,10 @@ cd ../snapi-client
rm -rf target/
sbt clean publishLocal

cd ../sql-client
rm -rf target/
sbt clean publishLocal

cd ../python-client
rm -rf target/
sbt clean publishLocal
Expand Down
2 changes: 1 addition & 1 deletion snapi-frontend/src/main/resources/reference.conf
Original file line number Diff line number Diff line change
Expand Up @@ -117,4 +117,4 @@ raw.sources.s3 {
tmp-dir = ${java.io.tmpdir}/s3

default-region = eu-west-1
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -360,20 +360,16 @@ class ClientCredentials(serverAddress: URI)(implicit settings: RawSettings) exte
restClient.doJsonPost[List[String]]("2/secrets/list", ListSecretCredentials(user), withAuth = false)
}

def close(): Unit = {
restClient.close()
def getUserDb(user: AuthenticatedUser): String = {
restClient.doJsonPost[String](
"2/fdw/provision",
ProvisionFdwDbCredentials(user),
withAuth = false
)
}

def getUserDb(user: AuthenticatedUser): String = {
try {
restClient.doJsonPost[String](
"2/fdw/provision",
ProvisionFdwDbCredentials(user),
withAuth = false
)
} catch {
case ex: ClientAPIException => null
}
def close(): Unit = {
restClient.close()
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,23 @@ class ClientCredentialsService(implicit settings: RawSettings) extends Credentia

private val client = new ClientCredentials(serverAddress)

private val dbCacheLoader = new CacheLoader[AuthenticatedUser, String]() {
override def load(user: AuthenticatedUser): String = {
// Directly call the provisioning method on the client
logger.debug(s"Retrieving user database for $user from origin server")
client.getUserDb(user)
}
}

private val dbCache: LoadingCache[AuthenticatedUser, String] = CacheBuilder
.newBuilder()
.maximumSize(settings.getIntOpt(CACHE_FDW_SIZE).getOrElse(DEFAULT_CACHE_FDW_SIZE).toLong)
.expireAfterAccess(
settings.getIntOpt(CACHE_FDW_EXPIRY_IN_HOURS).getOrElse(DEFAULT_CACHE_FDW_EXPIRY_IN_HOURS).toLong,
TimeUnit.HOURS
)
.build(dbCacheLoader)

/** S3 buckets */

override protected def doRegisterS3Bucket(user: AuthenticatedUser, bucket: S3Bucket): Boolean = {
Expand Down Expand Up @@ -294,33 +311,14 @@ class ClientCredentialsService(implicit settings: RawSettings) extends Credentia
}
}

override def doStop(): Unit = {
client.close()
}

// Define the CacheLoader
private val dbCacheLoader = new CacheLoader[AuthenticatedUser, String]() {
override def load(user: AuthenticatedUser): String = {
// Directly call the provisioning method on the client
logger.debug(s"Retrieving user database for $user from origin server")
client.getUserDb(user)
}
}

// Initialize FDW DB LoadingCache
private val dbCache: LoadingCache[AuthenticatedUser, String] = CacheBuilder
.newBuilder()
.maximumSize(settings.getIntOpt(CACHE_FDW_SIZE).getOrElse(DEFAULT_CACHE_FDW_SIZE).toLong)
.expireAfterAccess(
settings.getIntOpt(CACHE_FDW_EXPIRY_IN_HOURS).getOrElse(DEFAULT_CACHE_FDW_EXPIRY_IN_HOURS).toLong,
TimeUnit.HOURS
)
.build(dbCacheLoader)

override def getUserDb(user: AuthenticatedUser): String = {
// Retrieve the database name from the cache, provisioning it if necessary
logger.debug(s"Retrieving user database for $user")
dbCache.get(user)
}

override def doStop(): Unit = {
client.close()
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,10 @@ class LocalCredentialsService extends CredentialsService {
false
}

override def doStop(): Unit = {}

override def getUserDb(user: AuthenticatedUser): String = {
"default-user-db"
}

override def doStop(): Unit = {}

}
32 changes: 19 additions & 13 deletions sql-client/src/main/scala/raw/client/sql/SqlCompilerService.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import raw.client.api._
import raw.client.sql.antlr4.{ParseProgramResult, RawSqlSyntaxAnalyzer, SqlIdnNode, SqlParamUseNode}
import raw.client.sql.metadata.UserMetadataCache
import raw.client.sql.writers.{TypedResultSetCsvWriter, TypedResultSetJsonWriter}
import raw.creds.api.CredentialsServiceProvider
import raw.utils.{AuthenticatedUser, RawSettings, RawUtils}

import java.io.{IOException, OutputStream}
Expand All @@ -27,6 +28,22 @@ import scala.util.control.NonFatal
class SqlCompilerService(maybeClassLoader: Option[ClassLoader] = None)(implicit protected val settings: RawSettings)
extends CompilerService {

private val credentials = CredentialsServiceProvider(maybeClassLoader)

private val connectionPool = new SqlConnectionPool(credentials)

private val metadataBrowsers = {
val loader = new CacheLoader[AuthenticatedUser, UserMetadataCache] {
override def load(user: AuthenticatedUser): UserMetadataCache =
new UserMetadataCache(user, connectionPool, settings)
}
CacheBuilder
.newBuilder()
.maximumSize(settings.getInt("raw.client.sql.metadata-cache.size"))
.expireAfterAccess(settings.getDuration("raw.client.sql.metadata-cache.duration"))
.build(loader)
}

override def language: Set[String] = Set("sql")

private def safeParse(prog: String): Either[List[ErrorMessage], ParseProgramResult] = {
Expand Down Expand Up @@ -425,21 +442,10 @@ class SqlCompilerService(maybeClassLoader: Option[ClassLoader] = None)(implicit
}
}

private val connectionPool = new SqlConnectionPool(settings)
private val metadataBrowsers = {
val loader = new CacheLoader[AuthenticatedUser, UserMetadataCache] {
override def load(user: AuthenticatedUser): UserMetadataCache =
new UserMetadataCache(user, connectionPool, settings)
}
CacheBuilder
.newBuilder()
.maximumSize(settings.getInt("raw.client.sql.metadata-cache.size"))
.expireAfterAccess(settings.getDuration("raw.client.sql.metadata-cache.duration"))
.build(loader)
override def doStop(): Unit = {
credentials.stop()
}

override def doStop(): Unit = {}

private def pgRowTypeToIterableType(rowType: PostgresRowType): Either[Seq[String], RawIterableType] = {
val rowAttrTypes = rowType.columns
.map(c => SqlTypesUtils.rawTypeFromPgType(c.tipe).map(RawAttrType(c.name, _)))
Expand Down
13 changes: 5 additions & 8 deletions sql-client/src/main/scala/raw/client/sql/SqlConnectionPool.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,29 +13,26 @@
package raw.client.sql
import com.typesafe.scalalogging.StrictLogging
import com.zaxxer.hikari.{HikariConfig, HikariDataSource}
import raw.creds.api.CredentialsServiceProvider
import raw.creds.api.CredentialsService
import raw.utils.{AuthenticatedUser, RawSettings}

import java.sql.SQLException
import java.util.concurrent.TimeUnit
import scala.collection.mutable

class SqlConnectionPool(settings: RawSettings) extends StrictLogging {
// Make settings implicit within the local scope
implicit val implicitSettings: RawSettings = settings
class SqlConnectionPool(credentialsService: CredentialsService)(implicit settings: RawSettings) extends StrictLogging {

// one pool of connections per DB (which means per user).
// One pool of connections per DB (which means per user).
private val pools = mutable.Map.empty[String, HikariDataSource]
private val dbHost = settings.getString("raw.creds.jdbc.fdw.host")
private val dbPort = settings.getInt("raw.creds.jdbc.fdw.port")
private val readOnlyUser = settings.getString("raw.creds.jdbc.fdw.user")
private val password = settings.getString("raw.creds.jdbc.fdw.password")
private val client = CredentialsServiceProvider()

@throws[SQLException]
def getConnection(user: AuthenticatedUser): java.sql.Connection = {
val db = client.getUserDb(user) //user.uid.toString().replace("-", "_")
logger.info(s"Got database $db for user $user")
val db = credentialsService.getUserDb(user)
logger.debug(s"Got database $db for user $user")
getConnection(user, db)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,16 @@ package raw.client.sql

import org.bitbucket.inkytonik.kiama.util.Positions
import raw.client.sql.antlr4.RawSqlSyntaxAnalyzer
import raw.creds.api.CredentialsTestContext
import raw.creds.local.LocalCredentialsTestContext
import raw.utils._

class TestNamedParametersStatement extends RawTestSuite with SettingsTestContext with TrainingWheelsContext {
class TestNamedParametersStatement
extends RawTestSuite
with SettingsTestContext
with TrainingWheelsContext
with CredentialsTestContext
with LocalCredentialsTestContext {

private val database = sys.env.getOrElse("FDW_DATABASE", "raw")
private val hostname = sys.env.getOrElse("FDW_HOSTNAME", "localhost")
Expand All @@ -36,7 +43,7 @@ class TestNamedParametersStatement extends RawTestSuite with SettingsTestContext

override def beforeAll(): Unit = {
if (password != "") {
val connectionPool = new SqlConnectionPool(settings)
val connectionPool = new SqlConnectionPool(credentials)
con = connectionPool.getConnection(user)
}
super.beforeAll()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ package raw.client.sql

import raw.client.api._
import raw.creds.api.CredentialsTestContext
import raw.creds.local.LocalCredentialsService
import raw.creds.local.LocalCredentialsTestContext
import raw.utils._

import java.io.ByteArrayOutputStream
Expand All @@ -32,10 +32,10 @@ class TestSqlCompilerServiceAirports
extends RawTestSuite
with SettingsTestContext
with TrainingWheelsContext
with CredentialsTestContext {
with CredentialsTestContext
with LocalCredentialsTestContext {

private var compilerService: CompilerService = _
private var localCredentialsService: LocalCredentialsService = _

private val database = sys.env.getOrElse("FDW_DATABASE", "raw")
private val hostname = sys.env.getOrElse("FDW_HOSTNAME", "localhost")
Expand All @@ -53,25 +53,14 @@ class TestSqlCompilerServiceAirports

override def beforeAll(): Unit = {
super.beforeAll()
property("raw.creds.impl", "local")
localCredentialsService = new LocalCredentialsService()
setCredentials(localCredentialsService)
compilerService = new SqlCompilerService(None)
}

override def afterAll(): Unit = {
println(RawService.services)
if (compilerService != null) {
compilerService.stop()
compilerService = null
}
println(RawService.services.size())
if (localCredentialsService != null) {
RawUtils.withSuppressNonFatalException(localCredentialsService.stop())
localCredentialsService = null
}
println(RawService.services.size())
RawService.services.clear()
super.afterAll()
}

Expand Down

0 comments on commit 65c951b

Please sign in to comment.