Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Patch to RD-10677 #389

Merged
merged 1 commit into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions hard-rebuild.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,6 @@ cd ../client
rm -rf target/
sbt clean publishLocal

cd ../sql-client
rm -rf target/
sbt clean publishLocal

cd ../snapi-parser
rm -rf target/
sbt clean publishLocal
Expand All @@ -30,6 +26,10 @@ cd ../snapi-client
rm -rf target/
sbt clean publishLocal

cd ../sql-client
rm -rf target/
sbt clean publishLocal

cd ../python-client
rm -rf target/
sbt clean publishLocal
Expand Down
2 changes: 1 addition & 1 deletion snapi-frontend/src/main/resources/reference.conf
Original file line number Diff line number Diff line change
Expand Up @@ -117,4 +117,4 @@ raw.sources.s3 {
tmp-dir = ${java.io.tmpdir}/s3

default-region = eu-west-1
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -360,20 +360,16 @@ class ClientCredentials(serverAddress: URI)(implicit settings: RawSettings) exte
restClient.doJsonPost[List[String]]("2/secrets/list", ListSecretCredentials(user), withAuth = false)
}

def close(): Unit = {
restClient.close()
def getUserDb(user: AuthenticatedUser): String = {
restClient.doJsonPost[String](
"2/fdw/provision",
ProvisionFdwDbCredentials(user),
withAuth = false
)
}

def getUserDb(user: AuthenticatedUser): String = {
try {
restClient.doJsonPost[String](
"2/fdw/provision",
ProvisionFdwDbCredentials(user),
withAuth = false
)
} catch {
case ex: ClientAPIException => null
}
def close(): Unit = {
restClient.close()
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,23 @@ class ClientCredentialsService(implicit settings: RawSettings) extends Credentia

private val client = new ClientCredentials(serverAddress)

private val dbCacheLoader = new CacheLoader[AuthenticatedUser, String]() {
override def load(user: AuthenticatedUser): String = {
// Directly call the provisioning method on the client
logger.debug(s"Retrieving user database for $user from origin server")
client.getUserDb(user)
}
}

private val dbCache: LoadingCache[AuthenticatedUser, String] = CacheBuilder
.newBuilder()
.maximumSize(settings.getIntOpt(CACHE_FDW_SIZE).getOrElse(DEFAULT_CACHE_FDW_SIZE).toLong)
.expireAfterAccess(
settings.getIntOpt(CACHE_FDW_EXPIRY_IN_HOURS).getOrElse(DEFAULT_CACHE_FDW_EXPIRY_IN_HOURS).toLong,
TimeUnit.HOURS
)
.build(dbCacheLoader)

/** S3 buckets */

override protected def doRegisterS3Bucket(user: AuthenticatedUser, bucket: S3Bucket): Boolean = {
Expand Down Expand Up @@ -294,33 +311,14 @@ class ClientCredentialsService(implicit settings: RawSettings) extends Credentia
}
}

override def doStop(): Unit = {
client.close()
}

// Define the CacheLoader
private val dbCacheLoader = new CacheLoader[AuthenticatedUser, String]() {
override def load(user: AuthenticatedUser): String = {
// Directly call the provisioning method on the client
logger.debug(s"Retrieving user database for $user from origin server")
client.getUserDb(user)
}
}

// Initialize FDW DB LoadingCache
private val dbCache: LoadingCache[AuthenticatedUser, String] = CacheBuilder
.newBuilder()
.maximumSize(settings.getIntOpt(CACHE_FDW_SIZE).getOrElse(DEFAULT_CACHE_FDW_SIZE).toLong)
.expireAfterAccess(
settings.getIntOpt(CACHE_FDW_EXPIRY_IN_HOURS).getOrElse(DEFAULT_CACHE_FDW_EXPIRY_IN_HOURS).toLong,
TimeUnit.HOURS
)
.build(dbCacheLoader)

override def getUserDb(user: AuthenticatedUser): String = {
// Retrieve the database name from the cache, provisioning it if necessary
logger.debug(s"Retrieving user database for $user")
dbCache.get(user)
}

override def doStop(): Unit = {
client.close()
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,10 @@ class LocalCredentialsService extends CredentialsService {
false
}

override def doStop(): Unit = {}

override def getUserDb(user: AuthenticatedUser): String = {
"default-user-db"
}

override def doStop(): Unit = {}

}
32 changes: 19 additions & 13 deletions sql-client/src/main/scala/raw/client/sql/SqlCompilerService.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import raw.client.api._
import raw.client.sql.antlr4.{ParseProgramResult, RawSqlSyntaxAnalyzer, SqlIdnNode, SqlParamUseNode}
import raw.client.sql.metadata.UserMetadataCache
import raw.client.sql.writers.{TypedResultSetCsvWriter, TypedResultSetJsonWriter}
import raw.creds.api.CredentialsServiceProvider
import raw.utils.{AuthenticatedUser, RawSettings, RawUtils}

import java.io.{IOException, OutputStream}
Expand All @@ -27,6 +28,22 @@ import scala.util.control.NonFatal
class SqlCompilerService(maybeClassLoader: Option[ClassLoader] = None)(implicit protected val settings: RawSettings)
extends CompilerService {

private val credentials = CredentialsServiceProvider(maybeClassLoader)

private val connectionPool = new SqlConnectionPool(credentials)

private val metadataBrowsers = {
val loader = new CacheLoader[AuthenticatedUser, UserMetadataCache] {
override def load(user: AuthenticatedUser): UserMetadataCache =
new UserMetadataCache(user, connectionPool, settings)
}
CacheBuilder
.newBuilder()
.maximumSize(settings.getInt("raw.client.sql.metadata-cache.size"))
.expireAfterAccess(settings.getDuration("raw.client.sql.metadata-cache.duration"))
.build(loader)
}

override def language: Set[String] = Set("sql")

private def safeParse(prog: String): Either[List[ErrorMessage], ParseProgramResult] = {
Expand Down Expand Up @@ -425,21 +442,10 @@ class SqlCompilerService(maybeClassLoader: Option[ClassLoader] = None)(implicit
}
}

private val connectionPool = new SqlConnectionPool(settings)
private val metadataBrowsers = {
val loader = new CacheLoader[AuthenticatedUser, UserMetadataCache] {
override def load(user: AuthenticatedUser): UserMetadataCache =
new UserMetadataCache(user, connectionPool, settings)
}
CacheBuilder
.newBuilder()
.maximumSize(settings.getInt("raw.client.sql.metadata-cache.size"))
.expireAfterAccess(settings.getDuration("raw.client.sql.metadata-cache.duration"))
.build(loader)
override def doStop(): Unit = {
credentials.stop()
}

override def doStop(): Unit = {}

private def pgRowTypeToIterableType(rowType: PostgresRowType): Either[Seq[String], RawIterableType] = {
val rowAttrTypes = rowType.columns
.map(c => SqlTypesUtils.rawTypeFromPgType(c.tipe).map(RawAttrType(c.name, _)))
Expand Down
13 changes: 5 additions & 8 deletions sql-client/src/main/scala/raw/client/sql/SqlConnectionPool.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,29 +13,26 @@
package raw.client.sql
import com.typesafe.scalalogging.StrictLogging
import com.zaxxer.hikari.{HikariConfig, HikariDataSource}
import raw.creds.api.CredentialsServiceProvider
import raw.creds.api.CredentialsService
import raw.utils.{AuthenticatedUser, RawSettings}

import java.sql.SQLException
import java.util.concurrent.TimeUnit
import scala.collection.mutable

class SqlConnectionPool(settings: RawSettings) extends StrictLogging {
// Make settings implicit within the local scope
implicit val implicitSettings: RawSettings = settings
class SqlConnectionPool(credentialsService: CredentialsService)(implicit settings: RawSettings) extends StrictLogging {

// one pool of connections per DB (which means per user).
// One pool of connections per DB (which means per user).
private val pools = mutable.Map.empty[String, HikariDataSource]
private val dbHost = settings.getString("raw.creds.jdbc.fdw.host")
private val dbPort = settings.getInt("raw.creds.jdbc.fdw.port")
private val readOnlyUser = settings.getString("raw.creds.jdbc.fdw.user")
private val password = settings.getString("raw.creds.jdbc.fdw.password")
private val client = CredentialsServiceProvider()

@throws[SQLException]
def getConnection(user: AuthenticatedUser): java.sql.Connection = {
val db = client.getUserDb(user) //user.uid.toString().replace("-", "_")
logger.info(s"Got database $db for user $user")
val db = credentialsService.getUserDb(user)
logger.debug(s"Got database $db for user $user")
getConnection(user, db)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,16 @@ package raw.client.sql

import org.bitbucket.inkytonik.kiama.util.Positions
import raw.client.sql.antlr4.RawSqlSyntaxAnalyzer
import raw.creds.api.CredentialsTestContext
import raw.creds.local.LocalCredentialsTestContext
import raw.utils._

class TestNamedParametersStatement extends RawTestSuite with SettingsTestContext with TrainingWheelsContext {
class TestNamedParametersStatement
extends RawTestSuite
with SettingsTestContext
with TrainingWheelsContext
with CredentialsTestContext
with LocalCredentialsTestContext {

private val database = sys.env.getOrElse("FDW_DATABASE", "raw")
private val hostname = sys.env.getOrElse("FDW_HOSTNAME", "localhost")
Expand All @@ -36,7 +43,7 @@ class TestNamedParametersStatement extends RawTestSuite with SettingsTestContext

override def beforeAll(): Unit = {
if (password != "") {
val connectionPool = new SqlConnectionPool(settings)
val connectionPool = new SqlConnectionPool(credentials)
con = connectionPool.getConnection(user)
}
super.beforeAll()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ package raw.client.sql

import raw.client.api._
import raw.creds.api.CredentialsTestContext
import raw.creds.local.LocalCredentialsService
import raw.creds.local.LocalCredentialsTestContext
import raw.utils._

import java.io.ByteArrayOutputStream
Expand All @@ -32,10 +32,10 @@ class TestSqlCompilerServiceAirports
extends RawTestSuite
with SettingsTestContext
with TrainingWheelsContext
with CredentialsTestContext {
with CredentialsTestContext
with LocalCredentialsTestContext {

private var compilerService: CompilerService = _
private var localCredentialsService: LocalCredentialsService = _

private val database = sys.env.getOrElse("FDW_DATABASE", "raw")
private val hostname = sys.env.getOrElse("FDW_HOSTNAME", "localhost")
Expand All @@ -53,25 +53,14 @@ class TestSqlCompilerServiceAirports

override def beforeAll(): Unit = {
super.beforeAll()
property("raw.creds.impl", "local")
localCredentialsService = new LocalCredentialsService()
setCredentials(localCredentialsService)
compilerService = new SqlCompilerService(None)
}

override def afterAll(): Unit = {
println(RawService.services)
if (compilerService != null) {
compilerService.stop()
compilerService = null
}
println(RawService.services.size())
if (localCredentialsService != null) {
RawUtils.withSuppressNonFatalException(localCredentialsService.stop())
localCredentialsService = null
}
println(RawService.services.size())
RawService.services.clear()
super.afterAll()
}

Expand Down
Loading