Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Classpath isolation for Truffle language runtime #446

Merged
merged 13 commits into from
Jun 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,10 @@ lazy val sources = (project in file("sources"))
testSettings,
libraryDependencies ++= Seq(
apacheHttpClient,
springCore,
jwtApi,
jwtImpl,
jwtCore,
springCore,
dropboxSDK,
aws,
postgresqlDeps,
Expand Down Expand Up @@ -151,8 +151,7 @@ lazy val snapiFrontend = (project in file("snapi-frontend"))
kiama,
commonsCodec,
kryo
) ++
poiDeps
)
)

val calculateClasspath = taskKey[Seq[File]]("Calculate the full classpath")
Expand All @@ -179,7 +178,7 @@ lazy val snapiTruffle = (project in file("snapi-truffle"))
commonSettings,
snapiTruffleCompileSettings,
testSettings,
libraryDependencies ++= truffleCompiler ++ scalaCompiler,
libraryDependencies ++= truffleCompiler,
calculateClasspath := {
val dependencyFiles = (Compile / dependencyClasspath).value.files
val unmanagedFiles = (Compile / unmanagedClasspath).value.files
Expand Down
20 changes: 4 additions & 16 deletions project/Dependencies.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ object Dependencies {
"com.fasterxml.jackson.dataformat" % "jackson-dataformat-cbor" % "2.15.2",
"com.fasterxml.jackson.dataformat" % "jackson-dataformat-csv" % "2.15.2",
"com.fasterxml.jackson.dataformat" % "jackson-dataformat-yaml" % "2.15.2",
"com.fasterxml.jackson.module" %% "jackson-module-scala" % "2.15.2-rawlabs",
"com.fasterxml.jackson.jakarta.rs" % "jackson-jakarta-rs-json-provider" % "2.15.2"
"com.fasterxml.jackson.module" %% "jackson-module-scala" % "2.15.2-rawlabs"
)

// Required while we are on Scala 2.12. It's built into Scala 2.13.
Expand All @@ -53,9 +52,7 @@ object Dependencies {

// from snapi-frontend
val kiama = "org.bitbucket.inkytonik.kiama" %% "kiama" % "2.5.1-rawlabs"

val aws =
"software.amazon.awssdk" % "s3" % "2.20.69" exclude ("commons-logging", "commons-logging") // spring.jcl is the correct replacement for this one.
val aws = "software.amazon.awssdk" % "s3" % "2.20.69" exclude ("commons-logging", "commons-logging") // we use slf4j
val woodstox = "com.fasterxml.woodstox" % "woodstox-core" % "6.5.1"
val kryo = "com.esotericsoftware" % "kryo" % "5.5.0"
val commonsLang = "org.apache.commons" % "commons-lang3" % "3.13.0"
Expand All @@ -68,27 +65,18 @@ object Dependencies {
val oracleDeps = "com.oracle.database.jdbc" % "ojdbc10" % "19.23.0.0-rawlabs"
val teradataDeps = "com.teradata.jdbc" % "terajdbc" % "20.00.00.24"
val icuDeps = "com.ibm.icu" % "icu4j" % "73.2"
val poiDeps = Seq(
"org.apache.poi" % "poi" % "5.2.3",
"org.apache.poi" % "poi-ooxml" % "5.2.3",
"org.apache.poi" % "poi-ooxml-lite" % "5.2.3"
)
val jwtApi = "io.jsonwebtoken" % "jjwt-api" % "0.11.5"
val jwtImpl = "io.jsonwebtoken" % "jjwt-impl" % "0.11.5"
val jwtCore = "com.github.jwt-scala" %% "jwt-core" % "9.4.4-rawlabs"
val springCore = "org.springframework" % "spring-core" % "5.3.13"
val springCore =
"org.springframework" % "spring-core" % "5.3.13" exclude ("org.springframework", "spring-jcl") // we use jcl-over-slf4j
val truffleCompiler = Seq(
"org.graalvm.truffle" % "truffle-api" % "23.1.0",
"org.graalvm.truffle" % "truffle-api" % "23.1.0",
"org.graalvm.truffle" % "truffle-compiler" % "23.1.0",
"org.graalvm.truffle" % "truffle-nfi" % "23.1.0",
"org.graalvm.truffle" % "truffle-nfi-libffi" % "23.1.0",
"org.graalvm.truffle" % "truffle-runtime" % "23.1.0"
)
val scalaCompiler = Seq(
"org.scala-lang" % "scala-compiler" % "2.12.18",
"org.scala-lang" % "scala-reflect" % "2.12.18"
)

// from sql-client
val hikariCP = "com.zaxxer" % "HikariCP" % "5.1.0"
Expand Down
2 changes: 2 additions & 0 deletions snapi-client/src/main/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
requires raw.client;
requires raw.snapi.frontend;

uses raw.client.api.CompilerServiceBuilder;

provides raw.client.api.CompilerServiceBuilder with
raw.client.rql2.truffle.Rql2TruffleCompilerServiceBuilder;
}
2 changes: 0 additions & 2 deletions snapi-client/src/main/resources/reference.conf

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Copyright 2024 RAW Labs S.A.
*
* Use of this software is governed by the Business Source License
* included in the file licenses/BSL.txt.
*
* As of the Change Date specified in that file, in accordance with
* the Business Source License, use of this software will be governed
* by the Apache License, Version 2.0, included in the file
* licenses/APL.txt.
*/

package raw.client.rql2.truffle

import com.typesafe.scalalogging.StrictLogging

import java.lang.module.ModuleFinder
import java.nio.file.Paths
import scala.collection.JavaConverters._

/**
* Create a custom class and module loader for the given module path.
* This used to completely isolate the Truffle runtime from the rest of the system.
*/
trait CustomClassAndModuleLoader extends StrictLogging {

/**
* Create a custom class and module loader for the given module path.
* This is fully isolated, including only the system packages and whatever JARs are in the module path passed
* as an argument.
*
* @param modulePath the path to the module
* @return a class loader that contains only the system packages and the modules in the module path
*/
def createCustomClassAndModuleLoader(modulePath: String): ClassLoader = {

// Create a custom module finder for the module path
val modulePathFinder = ModuleFinder.of(Paths.get(modulePath))

// Get all modules in the module path
val modulesFromModulePathFinder = modulePathFinder
.findAll()
.toArray
.map(_.asInstanceOf[java.lang.module.ModuleReference].descriptor().name())
.toSet

modulesFromModulePathFinder.foreach(f => logger.trace(s"JAR path module: $f"))

// Get the parent layer
val parentLayer = ModuleLayer.boot()

// Get all modules in the system module finder
val systemModuleFinder = ModuleFinder.ofSystem()
val systemModules = parentLayer.configuration().modules().asScala.map(_.reference().descriptor().name()).toSet

systemModules.foreach(f => logger.trace(s"System module: $f"))

// Combine all modules
val allModules = modulesFromModulePathFinder ++ systemModules

// Creating a configuration, starting from the parent layer, and resolving all modules.
// The parent layer is the boot layer, so it contains all the system modules.
val configuration = parentLayer
.configuration()
.resolveAndBind(
modulePathFinder,
systemModuleFinder,
allModules.asJava
)

// Get the root classloader, which is the platform classloader
val rootClassLoader = ClassLoader.getPlatformClassLoader

// Create a module layer with the custom configuration, using the parent layer and the root classloader
val controller = ModuleLayer.defineModulesWithOneLoader(configuration, List(parentLayer).asJava, rootClassLoader)
val customLayer = controller.layer()

// Because we used ModuleLayer.defineModulesWithOneLoader, all the modules share the same classloader.
// So we find the first module (whatever that is), get its classloader, and that's the classloader for all of them.
val moduleName = customLayer.modules().iterator().next().getName
customLayer.findLoader(moduleName)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -39,14 +39,27 @@ import scala.collection.mutable
import scala.util.control.NonFatal

object Rql2TruffleCompilerService {
val language: Set[String] = Set("rql2", "rql2-truffle", "snapi")
val LANGUAGE: Set[String] = Set("rql2", "rql2-truffle", "snapi")

val JARS_PATH = "raw.client.rql2.jars-path"
}

class Rql2TruffleCompilerService(engineDefinition: (Engine, Boolean), maybeClassLoader: Option[ClassLoader])(
implicit protected val settings: RawSettings
) extends Rql2CompilerService
with CustomClassAndModuleLoader
with Rql2TypeUtils {

private val originalClassLoader = maybeClassLoader.getOrElse(Thread.currentThread().getContextClassLoader)

private val maybeTruffleClassLoader: Option[ClassLoader] = {
// If defined, contains the path used to create a classloader for the Truffle language runtime.
val maybeJarsPath = settings.getStringOpt(Rql2TruffleCompilerService.JARS_PATH)

// If the jars path is defined, create a custom class loader.
maybeJarsPath.map(jarsPath => createCustomClassAndModuleLoader(jarsPath))
}

private val (engine, initedEngine) = engineDefinition

// The default constructor allows an Engine to be specified, plus a flag to indicate whether it was created here
Expand All @@ -60,9 +73,9 @@ class Rql2TruffleCompilerService(engineDefinition: (Engine, Boolean), maybeClass
this(CompilerService.getEngine, maybeClassLoader)
}

override def language: Set[String] = Rql2TruffleCompilerService.language
override def language: Set[String] = Rql2TruffleCompilerService.LANGUAGE

private val credentials = CredentialsServiceProvider(maybeClassLoader)
private val credentials = CredentialsServiceProvider()

// Map of users to compiler context.
private val compilerContextCaches = new mutable.HashMap[AuthenticatedUser, CompilerContext]
Expand All @@ -76,13 +89,13 @@ class Rql2TruffleCompilerService(engineDefinition: (Engine, Boolean), maybeClass

private def createCompilerContext(user: AuthenticatedUser, language: String): CompilerContext = {
// Initialize source context
implicit val sourceContext = new SourceContext(user, credentials, settings, maybeClassLoader)
implicit val sourceContext = new SourceContext(user, credentials, settings, Some(originalClassLoader))

// Initialize inferrer
val inferrer = InferrerServiceProvider(maybeClassLoader)
val inferrer = InferrerServiceProvider()

// Initialize compiler context
new CompilerContext(language, user, inferrer, sourceContext, maybeClassLoader)
new CompilerContext(language, user, inferrer, sourceContext, Some(originalClassLoader))
}

private def getProgramContext(user: AuthenticatedUser, environment: ProgramEnvironment): ProgramContext = {
Expand Down Expand Up @@ -701,6 +714,14 @@ class Rql2TruffleCompilerService(engineDefinition: (Engine, Boolean), maybeClass
environment.options.get("staged-compiler").foreach { stagedCompiler =>
ctxBuilder.option("rql.staged-compiler", stagedCompiler)
}
ctxBuilder.option("rql.settings", settings.renderAsString)
// If the jars path is defined, create a custom class loader and set it as the host class loader.
maybeTruffleClassLoader.map { classLoader =>
// Set the module class loader as the Truffle runtime classloader.
// This enables the Truffle language runtime to be fully isolated from the rest of the application.
ctxBuilder.hostClassLoader(classLoader)
}

maybeOutputStream.foreach(os => ctxBuilder.out(os))
val ctx = ctxBuilder.build()
ctx
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ import raw.client.api.{CompilerService, CompilerServiceBuilder}
import raw.utils.RawSettings

class Rql2TruffleCompilerServiceBuilder extends CompilerServiceBuilder {
override def language: Set[String] = Rql2TruffleCompilerService.language
override def language: Set[String] = Rql2TruffleCompilerService.LANGUAGE

override def build(maybeClassLoader: Option[ClassLoader])(implicit settings: RawSettings): CompilerService = {
new Rql2TruffleCompilerService(maybeClassLoader)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import org.scalatest.matchers.{MatchResult, Matcher}
import raw.client.api._
import raw.client.rql2.api._
import raw.compiler.base.source.{BaseProgram, Type}
import raw.compiler.rql2.api.Rql2CompilerServiceTestContext
import raw.compiler.rql2.api.{Rql2CompilerServiceTestContext, Rql2OutputTestContext}
import raw.inferrer.local.LocalInferrerTestContext
import raw.utils._

Expand All @@ -29,12 +29,13 @@ import java.nio.file.{Files, Path, StandardOpenOption}
import scala.collection.mutable
import scala.io.Source

trait CompilerTestContext
trait Rql2CompilerTestContext
extends RawTestSuite
with Matchers
with SettingsTestContext
with TrainingWheelsContext
with Rql2CompilerServiceTestContext
with Rql2OutputTestContext

// Simple inferrer
with LocalInferrerTestContext {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@

package raw.compiler.rql2.tests.benchmark

import raw.compiler.rql2.tests.CompilerTestContext
import raw.compiler.rql2.tests.Rql2CompilerTestContext

trait BenchmarkTests extends CompilerTestContext {
trait BenchmarkTests extends Rql2CompilerTestContext {

property("raw.training-wheels", "false")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@

package raw.compiler.rql2.tests.benchmark

import raw.compiler.rql2.tests.CompilerTestContext
import raw.compiler.rql2.tests.Rql2CompilerTestContext

trait StressTests extends CompilerTestContext {
trait StressTests extends Rql2CompilerTestContext {

val shouldBeExecuted = false

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
package raw.compiler.rql2.tests.builtin

import raw.compiler.utils._
import raw.compiler.rql2.tests.CompilerTestContext
import raw.compiler.rql2.tests.Rql2CompilerTestContext

import java.nio.file.Path
import java.util.Base64

trait BinaryPackageTest extends CompilerTestContext {
trait BinaryPackageTest extends Rql2CompilerTestContext {

// FIXME (msb): This should use cast to support string to binary and do .getBytes("utf-8")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@

package raw.compiler.rql2.tests.builtin

import raw.compiler.rql2.tests.CompilerTestContext
import raw.compiler.rql2.tests.Rql2CompilerTestContext

trait BytePackageTest extends CompilerTestContext {
trait BytePackageTest extends Rql2CompilerTestContext {

test(""" Byte.From(1)""")(it => it should evaluateTo("1b"))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
package raw.compiler.rql2.tests.builtin

import raw.compiler.utils._
import raw.compiler.rql2.tests.CompilerTestContext
import raw.compiler.rql2.tests.Rql2CompilerTestContext

trait CsvPackageTest extends CompilerTestContext {
trait CsvPackageTest extends Rql2CompilerTestContext {

val ttt = "\"\"\""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@

package raw.compiler.rql2.tests.builtin

import raw.compiler.rql2.tests.CompilerTestContext
import raw.compiler.rql2.tests.Rql2CompilerTestContext

trait DatePackageTest extends CompilerTestContext {
trait DatePackageTest extends Rql2CompilerTestContext {

test("Date.Build(2022, 1, 15)") { it =>
it should typeAs("date")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@

package raw.compiler.rql2.tests.builtin

import raw.compiler.rql2.tests.CompilerTestContext
import raw.compiler.rql2.tests.Rql2CompilerTestContext

trait DecimalPackageTest extends CompilerTestContext {
trait DecimalPackageTest extends Rql2CompilerTestContext {

test("""Decimal.Round(Decimal.From("1.423"), 2)""") { it =>
it should evaluateTo("""1.42q""")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@

package raw.compiler.rql2.tests.builtin

import raw.compiler.rql2.tests.CompilerTestContext
import raw.compiler.rql2.tests.Rql2CompilerTestContext

trait DoublePackageTest extends CompilerTestContext {
trait DoublePackageTest extends Rql2CompilerTestContext {

test(""" Double.From(1)""")(it => it should evaluateTo("1.0"))

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
package raw.compiler.rql2.tests.builtin

import com.google.common.collect.HashMultiset
import raw.compiler.rql2.tests.CompilerTestContext
import raw.compiler.rql2.tests.Rql2CompilerTestContext

import scala.collection.JavaConverters._

trait EnvironmentPackageTest extends CompilerTestContext {
trait EnvironmentPackageTest extends Rql2CompilerTestContext {

test("""Environment.Secret("my-typo")""")(it => it should runErrorAs("could not find secret my-typo"))

Expand Down
Loading