From 65c951bad3bde5adf71e893f14447e237113bdcb Mon Sep 17 00:00:00 2001 From: Miguel Branco Date: Mon, 25 Mar 2024 15:49:15 +0100 Subject: [PATCH] Patch to RD-10677 (#389) --- hard-rebuild.sh | 8 ++-- .../src/main/resources/reference.conf | 2 +- .../raw/creds/client/ClientCredentials.scala | 20 ++++----- .../client/ClientCredentialsService.scala | 44 +++++++++---------- .../creds/local/LocalCredentialsService.scala | 5 ++- .../raw/client/sql/SqlCompilerService.scala | 32 ++++++++------ .../raw/client/sql/SqlConnectionPool.scala | 13 +++--- .../sql/TestNamedParametersStatement.scala | 11 ++++- .../sql/TestSqlCompilerServiceAirports.scala | 17 ++----- 9 files changed, 73 insertions(+), 79 deletions(-) diff --git a/hard-rebuild.sh b/hard-rebuild.sh index 701959471..a0dc46568 100755 --- a/hard-rebuild.sh +++ b/hard-rebuild.sh @@ -10,10 +10,6 @@ cd ../client rm -rf target/ sbt clean publishLocal -cd ../sql-client -rm -rf target/ -sbt clean publishLocal - cd ../snapi-parser rm -rf target/ sbt clean publishLocal @@ -30,6 +26,10 @@ cd ../snapi-client rm -rf target/ sbt clean publishLocal +cd ../sql-client +rm -rf target/ +sbt clean publishLocal + cd ../python-client rm -rf target/ sbt clean publishLocal diff --git a/snapi-frontend/src/main/resources/reference.conf b/snapi-frontend/src/main/resources/reference.conf index 2fda91f5c..0cfdc3012 100644 --- a/snapi-frontend/src/main/resources/reference.conf +++ b/snapi-frontend/src/main/resources/reference.conf @@ -117,4 +117,4 @@ raw.sources.s3 { tmp-dir = ${java.io.tmpdir}/s3 default-region = eu-west-1 -} +} \ No newline at end of file diff --git a/snapi-frontend/src/main/scala/raw/creds/client/ClientCredentials.scala b/snapi-frontend/src/main/scala/raw/creds/client/ClientCredentials.scala index 678604635..9a3a8538c 100644 --- a/snapi-frontend/src/main/scala/raw/creds/client/ClientCredentials.scala +++ b/snapi-frontend/src/main/scala/raw/creds/client/ClientCredentials.scala @@ -360,20 +360,16 @@ class ClientCredentials(serverAddress: URI)(implicit settings: RawSettings) exte restClient.doJsonPost[List[String]]("2/secrets/list", ListSecretCredentials(user), withAuth = false) } - def close(): Unit = { - restClient.close() + def getUserDb(user: AuthenticatedUser): String = { + restClient.doJsonPost[String]( + "2/fdw/provision", + ProvisionFdwDbCredentials(user), + withAuth = false + ) } - def getUserDb(user: AuthenticatedUser): String = { - try { - restClient.doJsonPost[String]( - "2/fdw/provision", - ProvisionFdwDbCredentials(user), - withAuth = false - ) - } catch { - case ex: ClientAPIException => null - } + def close(): Unit = { + restClient.close() } } diff --git a/snapi-frontend/src/main/scala/raw/creds/client/ClientCredentialsService.scala b/snapi-frontend/src/main/scala/raw/creds/client/ClientCredentialsService.scala index 6d1fe530e..375254b0b 100644 --- a/snapi-frontend/src/main/scala/raw/creds/client/ClientCredentialsService.scala +++ b/snapi-frontend/src/main/scala/raw/creds/client/ClientCredentialsService.scala @@ -41,6 +41,23 @@ class ClientCredentialsService(implicit settings: RawSettings) extends Credentia private val client = new ClientCredentials(serverAddress) + private val dbCacheLoader = new CacheLoader[AuthenticatedUser, String]() { + override def load(user: AuthenticatedUser): String = { + // Directly call the provisioning method on the client + logger.debug(s"Retrieving user database for $user from origin server") + client.getUserDb(user) + } + } + + private val dbCache: LoadingCache[AuthenticatedUser, String] = CacheBuilder + .newBuilder() + .maximumSize(settings.getIntOpt(CACHE_FDW_SIZE).getOrElse(DEFAULT_CACHE_FDW_SIZE).toLong) + .expireAfterAccess( + settings.getIntOpt(CACHE_FDW_EXPIRY_IN_HOURS).getOrElse(DEFAULT_CACHE_FDW_EXPIRY_IN_HOURS).toLong, + TimeUnit.HOURS + ) + .build(dbCacheLoader) + /** S3 buckets */ override protected def doRegisterS3Bucket(user: AuthenticatedUser, bucket: S3Bucket): Boolean = { @@ -294,33 +311,14 @@ class ClientCredentialsService(implicit settings: RawSettings) extends Credentia } } - override def doStop(): Unit = { - client.close() - } - - // Define the CacheLoader - private val dbCacheLoader = new CacheLoader[AuthenticatedUser, String]() { - override def load(user: AuthenticatedUser): String = { - // Directly call the provisioning method on the client - logger.debug(s"Retrieving user database for $user from origin server") - client.getUserDb(user) - } - } - - // Initialize FDW DB LoadingCache - private val dbCache: LoadingCache[AuthenticatedUser, String] = CacheBuilder - .newBuilder() - .maximumSize(settings.getIntOpt(CACHE_FDW_SIZE).getOrElse(DEFAULT_CACHE_FDW_SIZE).toLong) - .expireAfterAccess( - settings.getIntOpt(CACHE_FDW_EXPIRY_IN_HOURS).getOrElse(DEFAULT_CACHE_FDW_EXPIRY_IN_HOURS).toLong, - TimeUnit.HOURS - ) - .build(dbCacheLoader) - override def getUserDb(user: AuthenticatedUser): String = { // Retrieve the database name from the cache, provisioning it if necessary logger.debug(s"Retrieving user database for $user") dbCache.get(user) } + override def doStop(): Unit = { + client.close() + } + } diff --git a/snapi-frontend/src/main/scala/raw/creds/local/LocalCredentialsService.scala b/snapi-frontend/src/main/scala/raw/creds/local/LocalCredentialsService.scala index 6d45e8add..fee6e5add 100644 --- a/snapi-frontend/src/main/scala/raw/creds/local/LocalCredentialsService.scala +++ b/snapi-frontend/src/main/scala/raw/creds/local/LocalCredentialsService.scala @@ -145,9 +145,10 @@ class LocalCredentialsService extends CredentialsService { false } - override def doStop(): Unit = {} - override def getUserDb(user: AuthenticatedUser): String = { "default-user-db" } + + override def doStop(): Unit = {} + } diff --git a/sql-client/src/main/scala/raw/client/sql/SqlCompilerService.scala b/sql-client/src/main/scala/raw/client/sql/SqlCompilerService.scala index b57565fc9..d38c45f1e 100644 --- a/sql-client/src/main/scala/raw/client/sql/SqlCompilerService.scala +++ b/sql-client/src/main/scala/raw/client/sql/SqlCompilerService.scala @@ -18,6 +18,7 @@ import raw.client.api._ import raw.client.sql.antlr4.{ParseProgramResult, RawSqlSyntaxAnalyzer, SqlIdnNode, SqlParamUseNode} import raw.client.sql.metadata.UserMetadataCache import raw.client.sql.writers.{TypedResultSetCsvWriter, TypedResultSetJsonWriter} +import raw.creds.api.CredentialsServiceProvider import raw.utils.{AuthenticatedUser, RawSettings, RawUtils} import java.io.{IOException, OutputStream} @@ -27,6 +28,22 @@ import scala.util.control.NonFatal class SqlCompilerService(maybeClassLoader: Option[ClassLoader] = None)(implicit protected val settings: RawSettings) extends CompilerService { + private val credentials = CredentialsServiceProvider(maybeClassLoader) + + private val connectionPool = new SqlConnectionPool(credentials) + + private val metadataBrowsers = { + val loader = new CacheLoader[AuthenticatedUser, UserMetadataCache] { + override def load(user: AuthenticatedUser): UserMetadataCache = + new UserMetadataCache(user, connectionPool, settings) + } + CacheBuilder + .newBuilder() + .maximumSize(settings.getInt("raw.client.sql.metadata-cache.size")) + .expireAfterAccess(settings.getDuration("raw.client.sql.metadata-cache.duration")) + .build(loader) + } + override def language: Set[String] = Set("sql") private def safeParse(prog: String): Either[List[ErrorMessage], ParseProgramResult] = { @@ -425,21 +442,10 @@ class SqlCompilerService(maybeClassLoader: Option[ClassLoader] = None)(implicit } } - private val connectionPool = new SqlConnectionPool(settings) - private val metadataBrowsers = { - val loader = new CacheLoader[AuthenticatedUser, UserMetadataCache] { - override def load(user: AuthenticatedUser): UserMetadataCache = - new UserMetadataCache(user, connectionPool, settings) - } - CacheBuilder - .newBuilder() - .maximumSize(settings.getInt("raw.client.sql.metadata-cache.size")) - .expireAfterAccess(settings.getDuration("raw.client.sql.metadata-cache.duration")) - .build(loader) + override def doStop(): Unit = { + credentials.stop() } - override def doStop(): Unit = {} - private def pgRowTypeToIterableType(rowType: PostgresRowType): Either[Seq[String], RawIterableType] = { val rowAttrTypes = rowType.columns .map(c => SqlTypesUtils.rawTypeFromPgType(c.tipe).map(RawAttrType(c.name, _))) diff --git a/sql-client/src/main/scala/raw/client/sql/SqlConnectionPool.scala b/sql-client/src/main/scala/raw/client/sql/SqlConnectionPool.scala index 13735752e..26ca7b89d 100644 --- a/sql-client/src/main/scala/raw/client/sql/SqlConnectionPool.scala +++ b/sql-client/src/main/scala/raw/client/sql/SqlConnectionPool.scala @@ -13,29 +13,26 @@ package raw.client.sql import com.typesafe.scalalogging.StrictLogging import com.zaxxer.hikari.{HikariConfig, HikariDataSource} -import raw.creds.api.CredentialsServiceProvider +import raw.creds.api.CredentialsService import raw.utils.{AuthenticatedUser, RawSettings} import java.sql.SQLException import java.util.concurrent.TimeUnit import scala.collection.mutable -class SqlConnectionPool(settings: RawSettings) extends StrictLogging { - // Make settings implicit within the local scope - implicit val implicitSettings: RawSettings = settings +class SqlConnectionPool(credentialsService: CredentialsService)(implicit settings: RawSettings) extends StrictLogging { - // one pool of connections per DB (which means per user). + // One pool of connections per DB (which means per user). private val pools = mutable.Map.empty[String, HikariDataSource] private val dbHost = settings.getString("raw.creds.jdbc.fdw.host") private val dbPort = settings.getInt("raw.creds.jdbc.fdw.port") private val readOnlyUser = settings.getString("raw.creds.jdbc.fdw.user") private val password = settings.getString("raw.creds.jdbc.fdw.password") - private val client = CredentialsServiceProvider() @throws[SQLException] def getConnection(user: AuthenticatedUser): java.sql.Connection = { - val db = client.getUserDb(user) //user.uid.toString().replace("-", "_") - logger.info(s"Got database $db for user $user") + val db = credentialsService.getUserDb(user) + logger.debug(s"Got database $db for user $user") getConnection(user, db) } diff --git a/sql-client/src/test/scala/raw/client/sql/TestNamedParametersStatement.scala b/sql-client/src/test/scala/raw/client/sql/TestNamedParametersStatement.scala index 8080abb4f..e523d22a8 100644 --- a/sql-client/src/test/scala/raw/client/sql/TestNamedParametersStatement.scala +++ b/sql-client/src/test/scala/raw/client/sql/TestNamedParametersStatement.scala @@ -14,9 +14,16 @@ package raw.client.sql import org.bitbucket.inkytonik.kiama.util.Positions import raw.client.sql.antlr4.RawSqlSyntaxAnalyzer +import raw.creds.api.CredentialsTestContext +import raw.creds.local.LocalCredentialsTestContext import raw.utils._ -class TestNamedParametersStatement extends RawTestSuite with SettingsTestContext with TrainingWheelsContext { +class TestNamedParametersStatement + extends RawTestSuite + with SettingsTestContext + with TrainingWheelsContext + with CredentialsTestContext + with LocalCredentialsTestContext { private val database = sys.env.getOrElse("FDW_DATABASE", "raw") private val hostname = sys.env.getOrElse("FDW_HOSTNAME", "localhost") @@ -36,7 +43,7 @@ class TestNamedParametersStatement extends RawTestSuite with SettingsTestContext override def beforeAll(): Unit = { if (password != "") { - val connectionPool = new SqlConnectionPool(settings) + val connectionPool = new SqlConnectionPool(credentials) con = connectionPool.getConnection(user) } super.beforeAll() diff --git a/sql-client/src/test/scala/raw/client/sql/TestSqlCompilerServiceAirports.scala b/sql-client/src/test/scala/raw/client/sql/TestSqlCompilerServiceAirports.scala index ddb8ee402..8fd6e79a9 100644 --- a/sql-client/src/test/scala/raw/client/sql/TestSqlCompilerServiceAirports.scala +++ b/sql-client/src/test/scala/raw/client/sql/TestSqlCompilerServiceAirports.scala @@ -23,7 +23,7 @@ package raw.client.sql import raw.client.api._ import raw.creds.api.CredentialsTestContext -import raw.creds.local.LocalCredentialsService +import raw.creds.local.LocalCredentialsTestContext import raw.utils._ import java.io.ByteArrayOutputStream @@ -32,10 +32,10 @@ class TestSqlCompilerServiceAirports extends RawTestSuite with SettingsTestContext with TrainingWheelsContext - with CredentialsTestContext { + with CredentialsTestContext + with LocalCredentialsTestContext { private var compilerService: CompilerService = _ - private var localCredentialsService: LocalCredentialsService = _ private val database = sys.env.getOrElse("FDW_DATABASE", "raw") private val hostname = sys.env.getOrElse("FDW_HOSTNAME", "localhost") @@ -53,25 +53,14 @@ class TestSqlCompilerServiceAirports override def beforeAll(): Unit = { super.beforeAll() - property("raw.creds.impl", "local") - localCredentialsService = new LocalCredentialsService() - setCredentials(localCredentialsService) compilerService = new SqlCompilerService(None) } override def afterAll(): Unit = { - println(RawService.services) if (compilerService != null) { compilerService.stop() compilerService = null } - println(RawService.services.size()) - if (localCredentialsService != null) { - RawUtils.withSuppressNonFatalException(localCredentialsService.stop()) - localCredentialsService = null - } - println(RawService.services.size()) - RawService.services.clear() super.afterAll() }