Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into RD-10468
Browse files Browse the repository at this point in the history
  • Loading branch information
torcato committed Jan 23, 2024
2 parents 006d386 + 8c9208d commit 113a67d
Show file tree
Hide file tree
Showing 13 changed files with 387 additions and 45 deletions.
2 changes: 1 addition & 1 deletion hard-rebuild.sh
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ sbt clean publishLocal

cd ../snapi-truffle
rm -rf target/
sbt clean publishLocal
sbt clean runJavaAnnotationProcessor publishLocal

cd ../snapi-client
rm -rf target/
Expand Down
1 change: 1 addition & 0 deletions snapi-frontend/src/main/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
exports raw.compiler.rql2.errors;
exports raw.compiler.rql2.lsp;
exports raw.compiler.rql2.source;
exports raw.compiler.rql2.antlr4;
exports raw.compiler.utils;
exports raw.inferrer.api;
exports raw.inferrer.local;
Expand Down
4 changes: 1 addition & 3 deletions snapi-parser/build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,7 @@ Compile / packageSrc / publishArtifact := true

// Dependencies
libraryDependencies ++= Seq(
// We depend directly on the Truffle DSL processor to use their Antlr4.
// If we'd use ours, they would conflict as Truffle DSL package defines the org.antlr4 package.
"org.graalvm.truffle" % "truffle-dsl-processor" % "23.1.0"
"org.antlr" % "antlr4-runtime" % "4.12.0"
)

val generateParser = taskKey[Unit]("Generated antlr4 base parser and lexer")
Expand Down
4 changes: 2 additions & 2 deletions snapi-parser/src/main/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
*/

module raw.snapi.parser {
requires truffle.dsl.processor;

exports raw.compiler.rql2.generated;

requires org.antlr.antlr4.runtime;
}
83 changes: 76 additions & 7 deletions snapi-truffle/build.sbt
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import de.heikoseeberger.sbtheader.HeaderPlugin.autoImport._
import de.heikoseeberger.sbtheader.HeaderPlugin.autoImport.*
import sbt.Keys.*
import sbt.*

import sbt.Keys._
import sbt._
import Dependencies.*

import java.time.Year
import scala.sys.process.Process

import Dependencies._
import com.jsuereth.sbtpgp.PgpKeys.{publishSigned}

ThisBuild / sonatypeCredentialHost := "s01.oss.sonatype.org"

Expand Down Expand Up @@ -87,6 +88,12 @@ Test / doc / sources := {
(Compile / doc / sources).value.filterNot(_.getName.endsWith(".java"))
}

Compile / unmanagedSourceDirectories += baseDirectory.value / "target" / "java-processed-sources"

Compile / unmanagedResourceDirectories += baseDirectory.value / "target" / "java-processed-sources" / "META-INF"

Compile / resourceDirectories += baseDirectory.value / "target" / "java-processed-sources" / "META-INF"

// Add all the classpath to the module path.
Compile / javacOptions ++= Seq(
"--module-path",
Expand Down Expand Up @@ -121,15 +128,73 @@ Test / javaOptions ++= Seq(
"-Dpolyglotimpl.CompilationFailureAction=Throw"
)

val annotationProcessors = Seq(
"com.oracle.truffle.dsl.processor.TruffleProcessor",
"com.oracle.truffle.dsl.processor.verify.VerifyTruffleProcessor",
"com.oracle.truffle.dsl.processor.LanguageRegistrationProcessor",
"com.oracle.truffle.dsl.processor.InstrumentRegistrationProcessor",
"com.oracle.truffle.dsl.processor.OptionalResourceRegistrationProcessor",
"com.oracle.truffle.dsl.processor.InstrumentableProcessor",
"com.oracle.truffle.dsl.processor.verify.VerifyCompilationFinalProcessor",
"com.oracle.truffle.dsl.processor.OptionProcessor"
).mkString(",")

val calculateClasspath = taskKey[Seq[File]]("Calculate the full classpath")

calculateClasspath := {
val dependencyFiles = (Compile / dependencyClasspath).value.files
val unmanagedFiles = (Compile / unmanagedClasspath).value.files
val classesDir = (Compile / classDirectory).value

dependencyFiles ++ unmanagedFiles ++ Seq(classesDir)
}

val runJavaAnnotationProcessor = taskKey[Unit]("Runs the Java annotation processor")

runJavaAnnotationProcessor := {
println("Running Java annotation processor")

val annotationProcessorJar = baseDirectory.value / "truffle-dsl-processor-23.1.0.jar"

val javaSources = baseDirectory.value / "src" / "main" / "java"
val targetDir = baseDirectory.value / "target" / "java-processed-sources"

val projectClasspath = calculateClasspath.value.mkString(":")

val javacOptions = Seq(
"javac",
"-source",
"21",
"-target",
"21",
"-d",
targetDir.getAbsolutePath,
"--module-path",
projectClasspath,
"-cp",
annotationProcessorJar.getAbsolutePath,
"-processor",
annotationProcessors,
"-proc:only"
) ++ (javaSources ** "*.java").get.map(_.absolutePath)

// Create the target directory if it doesn't exist
targetDir.mkdirs()

// Execute the Java compiler
val result = Process(javacOptions).!
if (result != 0) {
throw new RuntimeException("Java annotation processing failed.")
}
}

// Add dependency resolvers
resolvers += Resolver.mavenLocal
resolvers += Resolver.sonatypeRepo("releases")

// Publish settings
Test / publishArtifact := true
Compile / packageSrc / publishArtifact := true
// When doing publishLocal, also publish to the local maven repository.
publishLocal := (publishLocal dependsOn publishM2).value

// Dependencies
libraryDependencies ++= Seq(
Expand All @@ -155,3 +220,7 @@ outputVersion := {
}

Compile / compile := ((Compile / compile) dependsOn outputVersion).value

publishLocal := (publishLocal dependsOn Def.sequential(runJavaAnnotationProcessor, outputVersion, publishM2)).value
publish := (publish dependsOn Def.sequential(runJavaAnnotationProcessor, outputVersion)).value
publishSigned := (publishSigned dependsOn Def.sequential(runJavaAnnotationProcessor, outputVersion)).value
1 change: 0 additions & 1 deletion snapi-truffle/project/Dependencies.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ object Dependencies {

val truffleCompiler = Seq(
"org.graalvm.truffle" % "truffle-api" % "23.1.0",
"org.graalvm.truffle" % "truffle-dsl-processor" % "23.1.0" % Provided,
"org.graalvm.truffle" % "truffle-api" % "23.1.0",
"org.graalvm.truffle" % "truffle-compiler" % "23.1.0",
"org.graalvm.truffle" % "truffle-nfi" % "23.1.0",
Expand Down
Binary file added snapi-truffle/truffle-dsl-processor-23.1.0.jar
Binary file not shown.
5 changes: 1 addition & 4 deletions sql-client/build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,6 @@ publishLocal := (publishLocal dependsOn Def.sequential(outputVersion, publishM2)
libraryDependencies ++= Seq(
rawClient % "compile->compile;test->test",
postgresqlDeps,
hikariCP,
// pretending a dependency on 'python' in order to have a truffle language, otherwise one cannot use the polyglot API
"org.graalvm.polyglot" % "python" % "23.1.0" % Provided
)
hikariCP)

Compile / packageBin / packageOptions += Package.ManifestAttributes("Automatic-Module-Name" -> "raw.sql.client")
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,10 @@ class NamedParametersPreparedStatement(conn: Connection, code: String) extends S
val typeInfo = {
val tipe = metadata.getColumnType(i)
val nullable = metadata.isNullable(i) == ResultSetMetaData.columnNullable
SqlTypesUtils.rawTypeFromJdbc(tipe).right.map(_.cloneWithFlags(nullable, false))
SqlTypesUtils.rawTypeFromJdbc(tipe).right.map {
case t: RawAnyType => t
case t: RawType => t.cloneWithFlags(nullable, false)
}
}
typeInfo.right.map(t => RawAttrType(name, t))
}
Expand Down
46 changes: 22 additions & 24 deletions sql-client/src/main/scala/raw/client/sql/SqlCompilerService.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,13 @@
package raw.client.sql

import com.google.common.cache.{CacheBuilder, CacheLoader}
import org.graalvm.polyglot.{Context, HostAccess, Value}
import raw.client.api._
import raw.client.sql.SqlCodeUtils._
import raw.client.writers.{TypedPolyglotCsvWriter, TypedPolyglotJsonWriter}
import raw.client.sql.writers.{TypedResultSetCsvWriter, TypedResultSetJsonWriter}
import raw.utils.{AuthenticatedUser, RawSettings, RawUtils}

import java.io.{IOException, OutputStream}
import java.sql.{SQLException, SQLTimeoutException}
import java.sql.{ResultSet, SQLException, SQLTimeoutException}
import scala.collection.mutable

class SqlCompilerService(maybeClassLoader: Option[ClassLoader] = None)(implicit protected val settings: RawSettings)
Expand Down Expand Up @@ -100,23 +99,22 @@ class SqlCompilerService(maybeClassLoader: Option[ClassLoader] = None)(implicit
val conn = connectionPool.getConnection(environment.user)
try {
val pstmt = new NamedParametersPreparedStatement(conn, source)
val result = pstmt.queryMetadata match {
case Right(info) =>
try {
val tipe = info.outputType
val access = HostAccess.newBuilder().allowMapAccess(true).allowIteratorAccess(true).build()
val ctx = Context.newBuilder().allowHostAccess(access).build()
environment.maybeArguments.foreach(array => setParams(pstmt, array))
val r = pstmt.executeQuery()
val v = ctx.asValue(new ResultSetIterator(r, ctx))
render(environment, tipe, v, outputStream)
} catch {
case e: SQLException => ExecutionRuntimeFailure(e.getMessage)
}
case Left(errors) => ExecutionValidationFailure(errors)
try {
pstmt.queryMetadata match {
case Right(info) =>
try {
val tipe = info.outputType
environment.maybeArguments.foreach(array => setParams(pstmt, array))
val r = pstmt.executeQuery()
render(environment, tipe, r, outputStream)
} catch {
case e: SQLException => ExecutionRuntimeFailure(e.getMessage)
}
case Left(errors) => ExecutionValidationFailure(errors)
}
} finally {
RawUtils.withSuppressNonFatalException(pstmt.close())
}
pstmt.close()
result
} catch {
case e: SQLException => ExecutionValidationFailure(mkError(source, e))
} finally {
Expand All @@ -131,22 +129,22 @@ class SqlCompilerService(maybeClassLoader: Option[ClassLoader] = None)(implicit
private def render(
environment: ProgramEnvironment,
tipe: RawType,
v: Value,
v: ResultSet,
outputStream: OutputStream
): ExecutionResponse = {
environment.options
.get("output-format")
.map(_.toLowerCase) match {
case Some("csv") =>
if (!TypedPolyglotCsvWriter.outputWriteSupport(tipe)) {
if (!TypedResultSetCsvWriter.outputWriteSupport(tipe)) {
ExecutionRuntimeFailure("unsupported type")
}
val windowsLineEnding = environment.options.get("windows-line-ending") match {
case Some("true") => true
case _ => false //settings.config.getBoolean("raw.compiler.windows-line-ending")
}
val lineSeparator = if (windowsLineEnding) "\r\n" else "\n"
val csvWriter = new TypedPolyglotCsvWriter(outputStream, lineSeparator)
val csvWriter = new TypedResultSetCsvWriter(outputStream, lineSeparator)
try {
csvWriter.write(v, tipe)
ExecutionSuccess
Expand All @@ -156,10 +154,10 @@ class SqlCompilerService(maybeClassLoader: Option[ClassLoader] = None)(implicit
RawUtils.withSuppressNonFatalException(csvWriter.close())
}
case Some("json") =>
if (!TypedPolyglotJsonWriter.outputWriteSupport(tipe)) {
if (!TypedResultSetJsonWriter.outputWriteSupport(tipe)) {
ExecutionRuntimeFailure("unsupported type")
}
val w = new TypedPolyglotJsonWriter(outputStream)
val w = new TypedResultSetJsonWriter(outputStream)
try {
w.write(v, tipe)
ExecutionSuccess
Expand Down
5 changes: 3 additions & 2 deletions sql-client/src/main/scala/raw/client/sql/SqlTypesUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ package raw.client.sql

import com.typesafe.scalalogging.StrictLogging
import raw.client.api.{
RawAnyType,
RawBoolType,
RawByteType,
RawDateType,
Expand All @@ -37,7 +38,7 @@ object SqlTypesUtils extends StrictLogging {

// a mapping from JDBC types to a RawType. We also store the name of the JDBC type for error reporting.
private val typeMap: Map[Int, SqlType] = Map(
java.sql.Types.BIT -> SqlType("BIT", None),
java.sql.Types.BIT -> SqlType("BIT", Some(RawBoolType(false, false))),
java.sql.Types.TINYINT -> SqlType("TINYINT", Some(RawByteType(false, false))),
java.sql.Types.SMALLINT -> SqlType("SMALLINT", Some(RawShortType(false, false))),
java.sql.Types.INTEGER -> SqlType("INTEGER", Some(RawIntType(false, false))),
Expand All @@ -57,7 +58,7 @@ object SqlTypesUtils extends StrictLogging {
java.sql.Types.VARBINARY -> SqlType("VARBINARY", None),
java.sql.Types.LONGVARBINARY -> SqlType("LONGVARBINARY", None),
java.sql.Types.NULL -> SqlType("NULL", None),
java.sql.Types.OTHER -> SqlType("OTHER", None),
java.sql.Types.OTHER -> SqlType("OTHER", Some(RawAnyType())),
java.sql.Types.JAVA_OBJECT -> SqlType("JAVA_OBJECT", None),
java.sql.Types.DISTINCT -> SqlType("DISTINCT", None),
java.sql.Types.STRUCT -> SqlType("STRUCT", None),
Expand Down
Loading

0 comments on commit 113a67d

Please sign in to comment.