diff --git a/build.sbt b/build.sbt index b0ab45a3a..19ad0148e 100644 --- a/build.sbt +++ b/build.sbt @@ -38,7 +38,8 @@ lazy val root = (project in file(".")) snapiTruffle, snapiClient, sqlParser, - sqlClient + sqlClient, + jinjaSqlClient ) .settings( commonSettings, @@ -282,5 +283,18 @@ lazy val pythonClient = (project in file("python-client")) missingInterpolatorCompileSettings, testSettings, Compile / packageBin / packageOptions += Package.ManifestAttributes("Automatic-Module-Name" -> "raw.python.client"), - libraryDependencies += "org.graalvm.polyglot" % "python" % "23.1.0" % Provided + libraryDependencies += trufflePython ) + +lazy val jinjaSqlClient = (project in file("jinja-sql-client")) + .dependsOn( + client % "compile->compile;test->test", + sqlClient % "compile->compile;test->test", + ) + .settings( + commonSettings, + missingInterpolatorCompileSettings, + testSettings, + libraryDependencies += trufflePython + ) + diff --git a/client/src/main/scala/raw/client/api/CompilerService.scala b/client/src/main/scala/raw/client/api/CompilerService.scala index ab172b32f..f75be36f5 100644 --- a/client/src/main/scala/raw/client/api/CompilerService.scala +++ b/client/src/main/scala/raw/client/api/CompilerService.scala @@ -63,7 +63,11 @@ object CompilerService { // options.put("engine.CompilationFailureAction", "Diagnose") // options.put("compiler.LogInlinedTargets", "true") } - val engine = Engine.newBuilder().allowExperimentalOptions(true).options(options).build() + val engine = Engine + .newBuilder() + .allowExperimentalOptions(true) + .options(options) + .build() enginesCache.put(settings, engine) (engine, true) } diff --git a/jinja-sql-client/project/build.properties b/jinja-sql-client/project/build.properties new file mode 100644 index 000000000..49214c4bb --- /dev/null +++ b/jinja-sql-client/project/build.properties @@ -0,0 +1 @@ +sbt.version = 1.9.9 diff --git a/jinja-sql-client/src/main/java/module-info.java b/jinja-sql-client/src/main/java/module-info.java 
new file mode 100644
index 000000000..fe42f472b
--- /dev/null
+++ b/jinja-sql-client/src/main/java/module-info.java
@@ -0,0 +1,23 @@
+/*
+ * Copyright 2024 RAW Labs S.A.
+ *
+ * Use of this software is governed by the Business Source License
+ * included in the file licenses/BSL.txt.
+ *
+ * As of the Change Date specified in that file, in accordance with
+ * the Business Source License, use of this software will be governed
+ * by the Apache License, Version 2.0, included in the file
+ * licenses/APL.txt.
+ */
+
+module raw.client.jinja.sql {
+    requires scala.library;
+    requires raw.client;
+    requires raw.utils;
+    requires raw.sources;
+    requires org.slf4j;
+    requires org.graalvm.polyglot;
+
+    provides raw.client.api.CompilerServiceBuilder with
+            raw.client.jinja.sql.JinjaSqlCompilerServiceBuilder;
+}
diff --git a/jinja-sql-client/src/main/resources/META-INF/services/raw.client.api.CompilerServiceBuilder b/jinja-sql-client/src/main/resources/META-INF/services/raw.client.api.CompilerServiceBuilder
new file mode 100644
index 000000000..2c056f037
--- /dev/null
+++ b/jinja-sql-client/src/main/resources/META-INF/services/raw.client.api.CompilerServiceBuilder
@@ -0,0 +1 @@
+raw.client.jinja.sql.JinjaSqlCompilerServiceBuilder
\ No newline at end of file
diff --git a/jinja-sql-client/src/main/resources/python/raw_jinja.py b/jinja-sql-client/src/main/resources/python/raw_jinja.py
new file mode 100644
index 000000000..534946795
--- /dev/null
+++ b/jinja-sql-client/src/main/resources/python/raw_jinja.py
@@ -0,0 +1,177 @@
+"""Jinja2 templating helpers for the jinja-sql compiler service.
+
+This module is evaluated inside a GraalPy context by the JVM host, which
+instantiates ``RawJinja`` and calls its ``apply``, ``validate`` and
+``metadata_comments`` members.
+"""
+
+from datetime import datetime, time
+
+from jinja2 import Environment, meta, nodes
+from jinja2.exceptions import TemplateRuntimeError
+from jinja2.ext import Extension
+from markupsafe import Markup
+
+
+class RaiseExtension(Extension):
+    """Jinja2 extension providing a ``{% raise <expression> %}`` tag."""
+
+    # The tag name(s) handled by this extension.
+    tags = {'raise'}
+
+    # See also: jinja2.parser.parse_include()
+    def parse(self, parser):
+        """Turn a ``raise`` tag into a call block that invokes ``_raise``."""
+        # The first token is the one that started the tag. We only listen to
+        # "raise", so it is a name token with "raise" as value. Its line
+        # number is handed to the nodes we insert.
+        lineno = next(parser.stream).lineno
+
+        # Extract the message expression from the template.
+        message_node = parser.parse_expression()
+
+        return nodes.CallBlock(
+            self.call_method('_raise', [message_node], lineno=lineno),
+            [], [], [], lineno=lineno
+        )
+
+    def _raise(self, msg, caller):
+        """Abort rendering by raising ``TemplateRuntimeError(msg)``."""
+        raise TemplateRuntimeError(msg)
+
+
+class RawJinjaException(Exception):
+    """Exception whose text is available through ``message()``."""
+
+    def __init__(self, message):
+        # Pass the message on so that str(exc) is not empty.
+        super().__init__(message)
+        self._message = message
+
+    def message(self):
+        """Return the message passed to the constructor."""
+        return self._message
+
+
+class RawEnvironment:
+    """Value bound to the ``environment`` template variable."""
+
+    def __init__(self, scopes, getSecret):
+        self.scopes = scopes
+        self._getSecret = getSecret
+
+    def secret(self, s):
+        """Look up secret ``s`` through the ``getSecret`` callable."""
+        return self._getSecret(s)
+
+
+class RawDate(datetime):
+    """Date argument; ``RawJinja.fix`` renders it as ``make_date(...)``."""
+
+
+class RawTimestamp(datetime):
+    """Timestamp argument; rendered as ``make_timestamp(...)``."""
+
+
+class RawTime(time):
+    """Time argument; rendered as ``make_time(...)``."""
+
+
+class RawJinja:
+    """Renders and inspects Jinja templates that produce SQL text."""
+
+    def __init__(self):
+        # 'finalize' runs on the value of every {{ ... }} expression.
+        self._env = Environment(finalize=self.fix, autoescape=False, extensions=[RaiseExtension])
+        self._env.filters['safe'] = self.flag_as_safe
+        self._env.filters['identifier'] = self.flag_as_identifier
+        # a default env to make sure 'environment' is predefined
+        self._env.globals['environment'] = RawEnvironment(None, None)
+
+    def fix(self, val):
+        """Convert the value of a template expression into SQL text.
+
+        ``Markup`` values are returned untouched, strings become quoted SQL
+        literals (single quotes doubled), dates/times/timestamps become
+        ``make_*`` calls, lists become a ``VALUES`` list and ``None`` becomes
+        ``NULL``. Anything else is returned unchanged.
+        """
+        if isinstance(val, Markup):
+            return val
+        elif isinstance(val, str):
+            return "'" + val.replace("'", "''") + "'"
+        elif isinstance(val, RawDate):
+            return "make_date(%d,%d,%d)" % (val.year, val.month, val.day)
+        elif isinstance(val, RawTime):
+            return "make_time(%d,%d,%d + %d / 1000000.0)" % (val.hour, val.minute, val.second, val.microsecond)
+        elif isinstance(val, RawTimestamp):
+            seconds = val.second + val.microsecond / 1000000.0
+            return "make_timestamp(%d,%d,%d,%d,%d,%f)" % (val.year, val.month, val.day, val.hour, val.minute, seconds)
+        elif isinstance(val, list):
+            # fix() returns non-string scalars (e.g. int) unchanged, so each item
+            # goes through str() before concatenation to avoid a TypeError.
+            items = ["(" + str(self.fix(i)) + ")" for i in val]
+            return "(VALUES " + ",".join(items) + ")"
+        elif val is None or val == None:  # noqa: E711
+            # NOTE(review): '== None' is kept on purpose; presumably it catches
+            # null values coming from the host that are not the None singleton.
+            return "NULL"
+        else:
+            return val
+
+    def flag_as_safe(self, s):
+        """Filter ``safe``: wrap ``s`` in ``Markup`` so ``fix`` pastes it as is."""
+        return Markup(s)
+
+    def flag_as_identifier(self, s):
+        """Filter ``identifier``: double-quote ``s``, doubling embedded quotes."""
+        return Markup('"' + s.replace('"', '""') + '"')
+
+    def _apply(self, code, args):
+        """Render template ``code`` with the variables in ``args``."""
+        template = self._env.from_string(code)
+        return template.render(args)
+
+    def apply(self, code, scopes, secret, args):
+        """Render ``code`` with arguments supplied by the host.
+
+        ``args`` is read through ``keySet()``/``get()``, ``scopes`` through
+        ``iterator()`` and ``secret`` through ``apply()``.
+        """
+        adapter = RawJavaPythonAdapter()
+        d = {key: adapter._toPython(args.get(key)) for key in args.keySet()}
+        d['environment'] = RawEnvironment(list(scopes.iterator()), lambda s: secret.apply(s))
+        return self._apply(code, d)
+
+    def validate(self, code):
+        """Parse ``code`` and return the names of its undeclared variables."""
+        tree = self._env.parse(code)
+        return list(meta.find_undeclared_variables(tree))
+
+    def metadata_comments(self, code):
+        """Return the content of every comment token found in ``code``."""
+        return [content for (_lineno, tipe, content) in self._env.lex(code) if tipe == 'comment']
+
+
+class RawJavaPythonAdapter:
+    """Converts ``java.time`` host values into the ``Raw*`` Python types."""
+
+    def __init__(self):
+        # NOTE(review): imported locally, presumably so that this module can
+        # still be imported where the 'java' module is unavailable.
+        import java
+        self.javaLocalDateClass = java.type('java.time.LocalDate')
+        self.javaLocalTimeClass = java.type('java.time.LocalTime')
+        self.javaLocalDateTimeClass = java.type('java.time.LocalDateTime')
+
+    def _toPython(self, arg):
+        """Return ``arg`` converted to RawDate/RawTime/RawTimestamp, or unchanged."""
+        if isinstance(arg, self.javaLocalDateClass):
+            return RawDate(arg.getYear(), arg.getMonthValue(), arg.getDayOfMonth())
+        if isinstance(arg, self.javaLocalTimeClass):
+            # Nanoseconds are truncated to whole microseconds.
+            return RawTime(arg.getHour(), arg.getMinute(), arg.getSecond(), int(arg.getNano() / 1000))
+        if isinstance(arg, self.javaLocalDateTimeClass):
+            return RawTimestamp(arg.getYear(), arg.getMonthValue(), arg.getDayOfMonth(),
+                                arg.getHour(), arg.getMinute(), arg.getSecond(), int(arg.getNano() / 1000))
+        return arg
\ No newline at end of file
diff
--git a/jinja-sql-client/src/main/resources/python/test_jinja.py b/jinja-sql-client/src/main/resources/python/test_jinja.py
new file mode 100644
index 000000000..bed3a640f
--- /dev/null
+++ b/jinja-sql-client/src/main/resources/python/test_jinja.py
@@ -0,0 +1,20 @@
+import unittest
+
+import raw_jinja
+
+
+class TestJinja(unittest.TestCase):
+    """Rendering checks for raw_jinja.RawJinja."""
+
+    def test_upper(self):
+        """Only the else-branch of the template is rendered."""
+        rawJinja = raw_jinja.RawJinja()
+        code = """{% if False %}
+SELECT {{ 1 }} AS v
+{% else %}
+SELECT ** {{ 2 }} AS v
+{% endif %}"""
+        result = rawJinja._apply(code, {})
+        # The line breaks around the selected branch are part of the rendered
+        # text, so surrounding whitespace is stripped before comparing.
+        self.assertEqual(result.strip(), """SELECT ** 2 AS v""")
\ No newline at end of file
diff --git a/jinja-sql-client/src/main/resources/reference.conf b/jinja-sql-client/src/main/resources/reference.conf
new file mode 100644
index 000000000..183f248c0
--- /dev/null
+++ b/jinja-sql-client/src/main/resources/reference.conf
@@ -0,0 +1,11 @@
+# NOTE(review): the graalpy paths below point into one developer's home
+# directory; they must be overridden wherever GraalPy is installed elsewhere.
+raw.client.jinja-sql {
+  graalpy {
+    executable = "/home/ld/python-venvs/graalpy23-venv/bin/python3"
+    home = "/home/ld/.pyenv/versions/graalpy-community-23.1.0"
+    core = "/home/ld/.pyenv/versions/graalpy-community-23.1.0/lib/graalpy23.1"
+    stdLib = "/home/ld/.pyenv/versions/graalpy-community-23.1.0/lib/python3.10"
+    logging = false
+  }
+}
\ No newline at end of file
diff --git a/jinja-sql-client/src/main/scala/raw/client/jinja/sql/JinjaSqlCompilerService.scala b/jinja-sql-client/src/main/scala/raw/client/jinja/sql/JinjaSqlCompilerService.scala
new file mode 100644
index 000000000..1e24178f0
--- /dev/null
+++ b/jinja-sql-client/src/main/scala/raw/client/jinja/sql/JinjaSqlCompilerService.scala
@@ -0,0 +1,302 @@
+/*
+ * Copyright 2024 RAW Labs S.A.
+ *
+ * Use of this software is governed by the Business Source License
+ * included in the file licenses/BSL.txt.
+ *
+ * As of the Change Date specified in that file, in accordance with
+ * the Business Source License, use of this software will be governed
+ * by the Apache License, Version 2.0, included in the file
+ * licenses/APL.txt.
+ */
+
+package raw.client.jinja.sql
+
+import org.graalvm.polyglot._
+import org.graalvm.polyglot.io.IOAccess
+import raw.client.api._
+import raw.creds.api.CredentialsServiceProvider
+import raw.utils.RawSettings
+
+class Env(val scopes: Value, val secret: String => String) {
+  // Carries what is handed to the template renderer: the scopes and a secret lookup function.
+}
+
+/**
+ * Compiler service for "jinja-sql": renders the Jinja template through GraalPy (raw_jinja.py)
+ * and hands the resulting SQL to the "sql" compiler service.
+ */
+class JinjaSqlCompilerService(maybeClassLoader: Option[ClassLoader] = None)(
+    implicit protected val settings: RawSettings
+) extends CompilerService {
+
+  private val JINJA_ERROR = "jinjaError"
+
+  private val (engine, _) = CompilerService.getEngine
+  private val sqlCompilerService = CompilerServiceProvider("sql", maybeClassLoader)
+
+  private val pythonCtx = {
+    val graalpyExecutable = settings.getString("raw.client.jinja-sql.graalpy.executable")
+    val graalpyHome = settings.getString("raw.client.jinja-sql.graalpy.home")
+    val graalpyCore = settings.getString("raw.client.jinja-sql.graalpy.core")
+    val graalpystdLib = settings.getString("raw.client.jinja-sql.graalpy.stdLib")
+    val logging = settings.getBoolean("raw.client.jinja-sql.graalpy.logging")
+    logger.info("pythonExecutable:" + graalpyExecutable)
+    val builder = Context
+      .newBuilder("python")
+      .engine(engine)
+      // .environment("RAW_SETTINGS", settings.renderAsString)
+      // .environment("RAW_USER", environment.user.uid.toString)
+      // .environment("RAW_TRACE_ID", environment.user.uid.toString)
+      // .environment("RAW_SCOPES", environment.scopes.mkString(","))
+      .allowExperimentalOptions(true)
+      .allowPolyglotAccess(PolyglotAccess.ALL)
+      .allowIO(IOAccess.ALL)
+      .allowHostAccess(HostAccess.ALL)
+      .allowHostClassLoading(true)
+      .allowHostClassLookup(_ => true)
+      .allowNativeAccess(true)
+      .option("python.DontWriteBytecodeFlag", "true")
+      .option("python.ForceImportSite", "true") // otherwise jinja2 isn't found
+      .option("python.PythonHome", graalpyHome)
+      .option("python.CoreHome", graalpyCore)
+      .option("python.StdLibHome", graalpystdLib)
+      .option("python.Executable", graalpyExecutable)
+    maybeClassLoader.foreach(builder.hostClassLoader)
+    if (logging) builder.option("python.VerboseFlag", "true").option("log.python.level", "NONE") // NOTE(review): level "NONE" while enabling logging -- confirm intent
+
+    builder.build()
+  }
+
+  private val credentials = CredentialsServiceProvider(maybeClassLoader)
+
+  private val bindings = {
+    val helper = getClass.getResource("/python/raw_jinja.py")
+    logger.info(helper.toString)
+    val truffleSource = Source.newBuilder("python", helper).build()
+    pythonCtx.eval(truffleSource)
+    pythonCtx.getBindings("python")
+  }
+
+  private val rawjinjaClass = bindings.getMember("RawJinja")
+  assert(rawjinjaClass.canInstantiate)
+  private val rawJinja = rawjinjaClass.newInstance()
+  private val apply = rawJinja.getMember("apply")
+  private val validate = rawJinja.getMember("validate")
+  private val metadataComments = rawJinja.getMember("metadata_comments")
+
+  def dotAutoComplete(
+      source: String,
+      environment: raw.client.api.ProgramEnvironment,
+      position: raw.client.api.Pos
+  ): raw.client.api.AutoCompleteResponse = AutoCompleteResponse(Array.empty)
+
+  def execute(
+      source: String,
+      environment: raw.client.api.ProgramEnvironment,
+      maybeDecl: Option[String],
+      outputStream: java.io.OutputStream
+  ): raw.client.api.ExecutionResponse = {
+    val args = new java.util.HashMap[String, Object]
+    codeArgs(source, environment) match {
+      case Left(errorMessages) => ExecutionRuntimeFailure(errorMessages.map(_.message).mkString(","))
+      case Right(params) =>
+        for (p <- params; value <- p.defaultValue) args.put(p.idn, rawValueToPolyglot(value))
+        for (userArgs <- environment.maybeArguments.toArray; (key, v) <- userArgs) args.put(key, rawValueToPolyglot(v))
+        val scopes = new java.util.ArrayList[String]
+        environment.scopes.foreach(scopes.add)
+        val env = new Env(
+          Value.asValue(scopes),
+          (secret: String) => credentials.getSecret(environment.user, secret).map(_.value).orNull
+        )
+        val sqlQuery: String =
+          try {
+            apply.execute(pythonCtx.asValue(source), env.scopes, env.secret, args).asString
+          } catch {
+            case ex: PolyglotException => handlePolyglotException(ex, source, environment) match {
+                case Some(errorMessage) => return ExecutionValidationFailure(List(errorMessage))
+                case None => throw new CompilerServiceException(ex, environment)
+              }
+          }
+        logger.debug(sqlQuery)
+        sqlCompilerService.execute(sqlQuery, environment, None, outputStream)
+    }
+  }
+
+  private def rawValueToPolyglot(value: RawValue) = value match {
+    case RawShort(v) => Value.asValue(v)
+    case RawInt(i) => Value.asValue(i)
+    case RawLong(v) => Value.asValue(v)
+    case RawFloat(v) => Value.asValue(v)
+    case RawDouble(v) => Value.asValue(v)
+    case RawBool(v) => Value.asValue(v)
+    case RawString(s) => Value.asValue(s)
+    case RawDate(d) => Value.asValue(d)
+    case RawTime(v) => Value.asValue(v)
+    case RawTimestamp(v) => Value.asValue(v)
+    case _ => ???
+  }
+
+  def eval(
+      source: String,
+      tipe: raw.client.api.RawType,
+      environment: raw.client.api.ProgramEnvironment
+  ): raw.client.api.EvalResponse = ???
+
+  def aiValidate(source: String, environment: raw.client.api.ProgramEnvironment): raw.client.api.ValidateResponse = ???
+
+  def formatCode(
+      source: String,
+      environment: raw.client.api.ProgramEnvironment,
+      maybeIndent: Option[Int],
+      maybeWidth: Option[Int]
+  ): raw.client.api.FormatCodeResponse = FormatCodeResponse(None)
+
+  private def codeArgs(
+      source: String,
+      environment: raw.client.api.ProgramEnvironment
+  ): Either[List[ErrorMessage], Vector[ParamDescription]] = {
+    val unknownArgs = {
+      try {
+        validate.execute(pythonCtx.asValue(source))
+      } catch {
+        case ex: PolyglotException => handlePolyglotException(ex, source, environment) match {
+            case Some(errorMessage) => return Left(List(errorMessage))
+            case None => throw new CompilerServiceException(ex, environment)
+          }
+      }
+    }
+    assert(unknownArgs.hasArrayElements)
+
+    val args = (0L until unknownArgs.getArraySize)
+      .map(unknownArgs.getArrayElement)
+      .map(_.asString)
+      .toVector
+      .map(s => s -> ParamDescription(s, Some(RawStringType(false, false)), None, None, true))
+      .toMap
+
+    val sqlArgs = {
+      val comments = Value.asValue(metadataComments.execute(pythonCtx.asValue(source)))
+      val metadata = (0L until comments.getArraySize)
+        .map(x => comments.getArrayElement(x))
+        .map(_.asString())
+        .filter(x => x.strip().startsWith("@"))
+        .map(s => "/*" + s + "*/")
+      val sqlCode = (metadata :+ "SELECT 1").mkString("\n")
+      sqlCompilerService.getProgramDescription(sqlCode, environment) match {
+        case GetProgramDescriptionSuccess(programDescription) =>
+          programDescription.maybeRunnable.get.params.get.map(p => p.idn -> p).toMap
+        case failure: GetProgramDescriptionFailure => return Left(failure.errors)
+      }
+    }
+    Right((args ++ sqlArgs).values.toVector)
+
+  }
+
+  def getProgramDescription(
+      source: String,
+      environment: raw.client.api.ProgramEnvironment
+  ): raw.client.api.GetProgramDescriptionResponse = {
+    codeArgs(source, environment) match {
+      case Left(errorMessages) => GetProgramDescriptionFailure(errorMessages)
+      case Right(allArgs) => GetProgramDescriptionSuccess(
+          ProgramDescription(
+            Map.empty,
+            Some(DeclDescription(Some(allArgs), Some(RawIterableType(RawAnyType(), false, false)), None)),
+            None
+          )
+        )
+    }
+  }
+
+  def goToDefinition(
+      source: String,
+      environment: raw.client.api.ProgramEnvironment,
+      position: raw.client.api.Pos
+  ): raw.client.api.GoToDefinitionResponse = GoToDefinitionResponse(None)
+
+  def hover(
+      source: String,
+      environment: raw.client.api.ProgramEnvironment,
+      position: raw.client.api.Pos
+  ): raw.client.api.HoverResponse = HoverResponse(None)
+
+  def rename(
+      source: String,
+      environment: raw.client.api.ProgramEnvironment,
+      position: raw.client.api.Pos
+  ): raw.client.api.RenameResponse = RenameResponse(Array.empty)
+
+  def validate(source: String, environment: raw.client.api.ProgramEnvironment): ValidateResponse = {
+    {
+      try {
+        validate.execute(pythonCtx.asValue(source))
+      } catch {
+        case ex: PolyglotException => handlePolyglotException(ex, source, environment) match {
+            case Some(errorMessage) => return ValidateResponse(List(errorMessage))
+            case None => throw new CompilerServiceException(ex, environment)
+          }
+      }
+    }
+    ValidateResponse(List.empty)
+  }
+
+  def wordAutoComplete(
+      source: String,
+      environment: raw.client.api.ProgramEnvironment,
+      prefix: String,
+      position: raw.client.api.Pos
+  ): raw.client.api.AutoCompleteResponse = AutoCompleteResponse(Array.empty)
+
+  // Members declared in raw.utils.RawService
+
+  def doStop(): Unit = {
+    pythonCtx.close(true)
+    sqlCompilerService.stop()
+    credentials.stop()
+  }
+
+  // Maps a PolyglotException to an ErrorMessage when it is a known Jinja error;
+  // interruptions and everything else are rethrown.
+  private def handlePolyglotException(
+      ex: PolyglotException,
+      source: String,
+      environment: raw.client.api.ProgramEnvironment
+  ): Option[ErrorMessage] = {
+    if (ex.isInterrupted || Option(ex.getMessage).exists(_.startsWith("java.lang.InterruptedException"))) {
+      throw new InterruptedException()
+    } else if (ex.getCause.isInstanceOf[InterruptedException]) {
+      throw ex.getCause
+    } else if (ex.isGuestException && !ex.isInternalError) {
+      val guestObject = ex.getGuestObject
+      val isException = guestObject.isException
+      assert(isException, s"$guestObject not an Exception!")
+      val exceptionClass = guestObject.getMetaObject.getMetaSimpleName
+      exceptionClass match {
+        case "TemplateSyntaxError" | "TemplateAssertionError" =>
+          val lineno = guestObject.getMember("lineno").asInt()
+          val message = guestObject.getMember("message").asString()
+          val location = ErrorPosition(lineno, 1)
+          // The reported line may lie outside of what split returns (trailing empty lines are dropped).
+          val lines = source.split('\n')
+          val lineLength = if (lineno >= 1 && lineno <= lines.length) lines(lineno - 1).length else 1
+          val endLocation = ErrorPosition(lineno, lineLength)
+          val range = ErrorRange(location, endLocation)
+          Some(ErrorMessage(message, List(range), JINJA_ERROR))
+        case "TemplateRuntimeError" | "UndefinedError" =>
+          val message = guestObject.getMember("message").asString()
+          Some(ErrorMessage(message, List.empty, JINJA_ERROR))
+        case _ => throw new CompilerServiceException(ex, environment)
+      }
+    } else {
+      throw ex
+    }
+  }
+
+  def build(maybeClassLoader: Option[ClassLoader])(
+      implicit settings: raw.utils.RawSettings
+  ): raw.client.api.CompilerService = ???
+
+  def language: Set[String] = Set("jinja-sql")
+
+}
diff --git a/jinja-sql-client/src/main/scala/raw/client/jinja/sql/JinjaSqlCompilerServiceBuilder.scala b/jinja-sql-client/src/main/scala/raw/client/jinja/sql/JinjaSqlCompilerServiceBuilder.scala
new file mode 100644
index 000000000..13db731de
--- /dev/null
+++ b/jinja-sql-client/src/main/scala/raw/client/jinja/sql/JinjaSqlCompilerServiceBuilder.scala
@@ -0,0 +1,24 @@
+/*
+ * Copyright 2024 RAW Labs S.A.
+ *
+ * Use of this software is governed by the Business Source License
+ * included in the file licenses/BSL.txt.
+ *
+ * As of the Change Date specified in that file, in accordance with
+ * the Business Source License, use of this software will be governed
+ * by the Apache License, Version 2.0, included in the file
+ * licenses/APL.txt.
+ */
+
+package raw.client.jinja.sql
+
+import raw.client.api.{CompilerService, CompilerServiceBuilder}
+import raw.utils.RawSettings
+
+class JinjaSqlCompilerServiceBuilder extends CompilerServiceBuilder { // builder registered for the "jinja-sql" language
+  override def language: Set[String] = Set("jinja-sql")
+
+  override def build(maybeClassLoader: Option[ClassLoader])(implicit settings: RawSettings): CompilerService =
+    new JinjaSqlCompilerService(maybeClassLoader) // the implicit 'settings' is passed on to the service
+
+}
diff --git a/jinja-sql-client/src/test/scala/raw/client/jinja/sql/PreprocessingTest.scala b/jinja-sql-client/src/test/scala/raw/client/jinja/sql/PreprocessingTest.scala
new file mode 100644
index 000000000..e78d5c2bd
--- /dev/null
+++ b/jinja-sql-client/src/test/scala/raw/client/jinja/sql/PreprocessingTest.scala
@@ -0,0 +1,438 @@
+/*
+ * Copyright 2024 RAW Labs S.A.
+ *
+ * Use of this software is governed by the Business Source License
+ * included in the file licenses/BSL.txt.
+ *
+ * As of the Change Date specified in that file, in accordance with
+ * the Business Source License, use of this software will be governed
+ * by the Apache License, Version 2.0, included in the file
+ * licenses/APL.txt.
+ */ + +package raw.client.jinja.sql + +import org.scalatest.matchers.must.Matchers.{be, contain} +import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper +import org.scalatest.matchers.{MatchResult, Matcher} +import raw.client.api.{GetProgramDescriptionSuccess, _} +import raw.creds.api.CredentialsTestContext +import raw.creds.local.LocalCredentialsTestContext +import raw.utils._ + +import java.io.ByteArrayOutputStream +import java.time.{LocalDate, LocalDateTime, LocalTime} + +class PreprocessingTest + extends RawTestSuite + with SettingsTestContext + with TrainingWheelsContext + with CredentialsTestContext + with LocalCredentialsTestContext { + + private val database = sys.env.getOrElse("FDW_DATABASE", "unittest") + private val hostname = sys.env.getOrElse("FDW_HOSTNAME", "localhost") + private val port = sys.env.getOrElse("FDW_HOSTNAME", "5432") + private val username = sys.env.getOrElse("FDW_USERNAME", "postgres") + private val password = sys.env.getOrElse("FDW_PASSWORD", "") + + property("raw.creds.jdbc.fdw.host", hostname) + property("raw.creds.jdbc.fdw.port", port) + property("raw.creds.jdbc.fdw.user", username) + property("raw.creds.jdbc.fdw.password", password) + + private case class Q(code: String, args: Map[String, Any] = Map.empty) { + def withArgs(values: Map[String, Any]): Q = Q(code, values) + def withArg(arg: (String, Any)): Q = withArgs(Map(arg)) + + def description(): GetProgramDescriptionResponse = { + compilerService.getProgramDescription( + code, + environment = ProgramEnvironment(user, None, Set.empty, Map("output-format" -> "json")) + ) + } + + } + + private class Give(jsonString: String) extends Matcher[Q] { + override def apply(left: Q): MatchResult = { + val rawArgs: Array[(String, RawValue)] = left.args.mapValues { + case s: String => RawString(s) + case i: Int => RawInt(i) + case d: LocalDate => RawDate(d) + case v: LocalDateTime => RawTimestamp(v) + case v: LocalTime => RawTime(v) + }.toArray + val env = 
ProgramEnvironment(user, Some(rawArgs), Set.empty, Map("output-format" -> "json")) + val baos = new ByteArrayOutputStream() + val r = compilerService.execute(left.code, env, None, baos) + r match { + case ExecutionSuccess => + val result = baos.toString + MatchResult(result == jsonString, s"Actual: $result\nExpected: $jsonString", "ok") + case ExecutionRuntimeFailure(error) => MatchResult(false, error, "ok") + case ExecutionValidationFailure(errors) => MatchResult(false, errors.map(_.message).mkString(","), "ok") + } + } + } + + private class FailWith(errorMessage: String) extends Matcher[Q] { + override def apply(left: Q): MatchResult = { + val rawArgs: Array[(String, RawValue)] = left.args.mapValues { + case s: String => RawString(s) + case i: Int => RawInt(i) + }.toArray + val env = ProgramEnvironment(user, Some(rawArgs), Set.empty, Map("output-format" -> "json")) + val baos = new ByteArrayOutputStream() + val r = compilerService.execute(left.code, env, None, baos) + r match { + case ExecutionRuntimeFailure(error) => MatchResult(error.contains(errorMessage), error, "ok") + case ExecutionValidationFailure(errors) => + MatchResult(errors.exists(_.message.contains(errorMessage)), errors.map(_.message).mkString(","), "ok") + case ExecutionSuccess => MatchResult(false, "didn't fail", "ok") + } + } + } + + private class Validate extends Matcher[Q] { + override def apply(left: Q): MatchResult = { + val env = ProgramEnvironment(user, None, Set.empty, Map("output-format" -> "json")) + val ValidateResponse(validationErrors) = compilerService.validate(left.code, env) + compilerService.getProgramDescription(left.code, env) match { + case GetProgramDescriptionFailure(descriptionErrors) => + MatchResult(false, (validationErrors ++ descriptionErrors).mkString(","), "ok") + case _: GetProgramDescriptionSuccess => + MatchResult(validationErrors.isEmpty, validationErrors.mkString(","), "ok") + } + MatchResult(validationErrors.isEmpty, validationErrors.mkString(","), "ok") + } + } 
+ + private def give(jsonContent: String) = new Give(jsonContent) + private def failWith(error: String) = new FailWith(error) + private def validate = new Validate + + test("safe vs. unsafe") { _ => + // without '|safe' the string variable is turned into a quoted string + val code = Q("SELECT MIN({{ column }}) FROM example.airports WHERE country = {{ country }}") + code withArgs Map("column" -> "latitude", "country" -> "France") should give("""[{"min":"latitude"}]""") + code withArgs Map("column" -> "longitude", "country" -> "France") should give("""[{"min":"longitude"}]""") + + // with '|safe' the string variable is pasted untouched/unquoted (e.g. here, it's used as a column name) + val code2 = Q("SELECT MAX({{ column|safe }}) FROM example.airports WHERE country = {{ country }}") + code2 withArgs Map("column" -> "latitude", "country" -> "France") should give("""[{"max":50.967536}]""") + code2 withArgs Map("column" -> "longitude", "country" -> "France") should give("""[{"max":9.483731}]""") + + // '|safe' used when computing a local variable, which is then used as safe even though it's not flagged directly, + val code3 = Q("""{% set col = column|safe %} + |SELECT MAX({{ col }}) FROM example.airports WHERE country = {{ country }}""".stripMargin) + code3 withArgs Map("column" -> "latitude", "country" -> "France") should give("""[{"max":50.967536}]""") + code3 withArgs Map("column" -> "longitude", "country" -> "France") should give("""[{"max":9.483731}]""") + + } + + test("escape") { _ => + Q("SELECT {{ v }} AS v") withArg ("v" -> "
") should give("""[{"v":""}]""") + Q("SELECT '{{ v|safe }}' AS v") withArg ("v" -> "") should give("""[{"v":""}]""") + Q("SELECT {{ v }} AS v") withArg ("v" -> "''") should give("""[{"v":"''"}]""") + Q("SELECT '{{ v }}' AS v") withArg ("v" -> "") should failWith("column \"p\" does not exist") + Q("SELECT {{ v }} AS v") withArg ("v" -> "\"\"") should give("""[{"v":"\"\""}]""") + Q("SELECT {{ v }} AS v") withArg ("v" -> "1 + 2") should give("""[{"v":"1 + 2"}]""") // quoted by default + Q("SELECT {{ v|safe }} AS v") withArg ("v" -> "1 + 2") should give( + """[{"v":3}]""" + ) // pasted straight => interpreted + } + + /* + -- @param v the date + -- of the event + SELECT 1; + */ + + test("""integer parameter""") { _ => + val code = Q(s""" + |{# @type v integer #} + |{# @default v 12 #} + |{# @param v a random number + | here to test something #} + |SELECT {{ v }} * 10 AS r + |""".stripMargin) + val GetProgramDescriptionSuccess(d) = code.description() + d.decls.size should be(0) + val Vector(param) = d.maybeRunnable.get.params.get + param.idn should be("v") + param.required should be(false) // because we have a default + param.tipe.get should be(RawIntType(true, false)) // null is always OK + param.comment.get should be("a random number here to test something") + code withArg "v" -> 22 should give("""[{"r":220}]""") + code withArg "v" -> "tralala" should failWith("invalid input syntax for type integer") + } + + test("""date parameter (with default)""") { _ => + val code = Q(s""" + |{# @type v date #} + |{# @default v '2001-01-01' #} + |{# @param v a random date + | here to test something #} + |SELECT EXTRACT('YEAR' FROM {{ v }}) AS r + |""".stripMargin) + val GetProgramDescriptionSuccess(d) = code.description() + d.decls.size should be(0) + val Vector(param) = d.maybeRunnable.get.params.get + param.idn should be("v") + param.required should be(false) // because we have a default + param.tipe.get should be(RawDateType(true, false)) // null is always OK + param.comment.get 
should be("a random date here to test something") + code withArg "v" -> LocalDate.of(2024, 1, 1) should give("""[{"r":2024.0}]""") + code withArgs Map.empty should give("""[{"r":2001.0}]""") + } + + test("""timestamp parameter (with default)""") { _ => + val code = Q(s""" + |{# @type v timestamp #} + |{# @default v '2001-01-01 12:34:56.099' #} + |{# @param v a random timestamp + | here to test something #} + |SELECT {{ v }} AS r + |""".stripMargin) + val GetProgramDescriptionSuccess(d) = code.description() + d.decls.size should be(0) + val Vector(param) = d.maybeRunnable.get.params.get + param.idn should be("v") + param.required should be(false) // because we have a default + param.tipe.get should be(RawTimestampType(true, false)) // null is always OK + param.comment.get should be("a random timestamp here to test something") + code withArg "v" -> LocalDateTime.of(2024, 1, 1, 1, 1, 1, 3309000) should give( + """[{"r":"2024-01-01T01:01:01.003"}]""" + ) + code withArgs Map.empty should give("""[{"r":"2001-01-01T12:34:56.099"}]""") + } + + test("""time parameter (with default)""") { _ => + val code = Q(s""" + |{# @type v time #} + |{# @default v '12:34:56.099' #} + |{# @param v a random time + | here to test something #} + |SELECT {{ v }} AS r + |""".stripMargin) + val GetProgramDescriptionSuccess(d) = code.description() + d.decls.size should be(0) + val Vector(param) = d.maybeRunnable.get.params.get + param.idn should be("v") + param.required should be(false) // because we have a default + param.tipe.get should be(RawTimeType(true, false)) // null is always OK + param.comment.get should be("a random time here to test something") + code withArg "v" -> LocalTime.of(1, 1, 1, 3309000) should give("""[{"r":"01:01:01.000"}]""") + code withArgs Map.empty should give("""[{"r":"12:34:56.000"}]""") + } + + test("""date expression""") { _ => + val code = Q(s""" + |{# @type v date #} + |{# @param v a random date + | here to test something #} + |SELECT {{ v.replace(year=2001) }} 
AS r + |""".stripMargin) + val GetProgramDescriptionSuccess(d) = code.description() + d.decls.size should be(0) + val Vector(param) = d.maybeRunnable.get.params.get + param.idn should be("v") + param.required should be(true) // because there is no default + param.tipe.get should be(RawDateType(true, false)) // null is always OK + param.comment.get should be("a random date here to test something") + code withArg "v" -> LocalDate.of(2024, 1, 1) should give("""[{"r":"2001-01-01"}]""") + code withArg "v" -> LocalDate.of(1978, 3, 8) should give("""[{"r":"2001-03-08"}]""") + } + + test("""SELECT {{ column|safe + "y" }}, COUNT(*) + |FROM example.airports + |GROUP BY {{ column|safe + "y" }} + |ORDER BY COUNT(*) DESC + |LIMIT 3 + |""".stripMargin) { t => + val q = Q(t.q) + q should validate + q withArg ("column" -> "cit") should give( + """[{"city":"London","count":21},{"city":"New York","count":13},{"city":"Hong Kong","count":12}]""" + ) + q withArg ("column" -> "countr") should give( + """[{"country":"United States","count":1697},{"country":"Canada","count":435},{"country":"Germany","count":321}]""" + ) + } + + test(s""" + |SELECT COUNT(*) AS n + |FROM example.airports + |WHERE + |{% if key == "country" %} + | country + |{% elif key == "city" %} + | city + |{% else %} + | {% raise "error, unknown key: " + key %} + |{% endif %} + | = {{ value }} + |""".stripMargin) { t => + val q = Q(t.q) + q should validate + val GetProgramDescriptionSuccess(description) = q.description() + val argNames = description.maybeRunnable.get.params.get.map(_.idn) + argNames should (contain("key") and contain("value")) + q withArgs Map("key" -> "city", "value" -> "Athens") should give("""[{"n":6}]""") + q withArgs Map("key" -> "country", "value" -> "Greece") should give("""[{"n":60}]""") + q withArgs Map("key" -> "iata", "value" -> "GVA") should failWith("error, unknown key: iata") + } + + ignore("""SELECT {{ key }}, + | COUNT(*), + | SUM({{ raw:fail() }}) + |FROM example.airports + |GROUP BY {{ 
key }} + |""".stripMargin) { q => + val v = compilerService.validate(q.q, asJson()) + assert(v != null) + val baos = new ByteArrayOutputStream() + val r = compilerService.execute(q.q, asJson(Map("key" -> RawString("city"))), None, baos) + assert(r == ExecutionSuccess) + } + + test("""SELECT {{ 1 - 2 }} AS v""".stripMargin) { t => + val q = Q(t.q) + q should validate + q should give("""[{"v":-1}]""") + } + + test("""SELECT {{ a }} + {{ b }} AS v + |""".stripMargin) { t => + val q = Q(t.q) + q should validate + val GetProgramDescriptionSuccess(d) = q.description() + d.decls.size should be(0) + val names = d.maybeRunnable.get.params.get.map(_.idn) + names should (contain("a") and contain("b")) + q withArgs (Map("a" -> 1, "b" -> 2)) should give("""[{"v":3}]""") + q withArgs (Map("a" -> 11, "b" -> 22)) should give("""[{"v":33}]""") + } + + test("""SELECT {{ a + 1 - b }} AS v + |""".stripMargin) { t => + val q = Q(t.q) + q should validate + val GetProgramDescriptionSuccess(d) = q.description() + d.decls.size should be(0) + val names = d.maybeRunnable.get.params.get.map(_.idn) + names should (contain("a") and contain("b")) + q withArgs (Map("a" -> 1, "b" -> 2)) should give("""[{"v":0}]""") + q withArgs (Map("a" -> 11, "b" -> 22)) should give("""[{"v":-10}]""") + } + + test("""SELECT {{ 1 +> 2 }} + |""".stripMargin) { t => + val q = Q(t.q) + assert(compilerService.validate(t.q, asJson()).messages.map(_.message).exists(_.contains("unexpected '>'"))) + q should failWith("unexpected '>'") + } + + ignore(s""" + |{% set v = val == "latitude" ? 
"latitude" : "longitude" %} + |SELECT {{ key }}, MAX({{ v }}), MIN({{ v }}) + |FROM example.airports GROUP BY {{ key }} + |ORDER BY COUNT(*) {{ order }} + |LIMIT 3 + |""".stripMargin) { q => + val g = compilerService.getProgramDescription(q.q, asJson()) + assert(g != null) + val baos = new ByteArrayOutputStream() + val r = compilerService.execute( + q.q, + asJson(Map("key" -> RawString("country"), "val" -> RawString("latitude"), "order" -> RawString("DESC"))), + None, + baos + ) + assert(r == ExecutionSuccess) + + } + + test(s""" + |{% set v = "latitude" if val == "latitude" else "longitude" %} + |SELECT {{ key|safe }}, MAX({{ v|safe }}), MIN({{ v|safe }}) + |FROM example.airports GROUP BY {{ key|safe }} + |ORDER BY COUNT(*) {{ order|safe }} + |LIMIT 3 + |""".stripMargin) { t => + val q = Q(t.q) + q should validate + q withArgs Map( + "key" -> "country", + "val" -> "latitude", + "order" -> "DESC" + ) should give( + """[{"country":"United States","max":72.270833,"min":-1.111100},{"country":"Canada","max":82.517778,"min":42.199000},{"country":"Germany","max":54.913250,"min":0.000000}]""" + ) + } + + test(s""" + |SELECT {{key}}, + | COUNT(*), + | SUM( + | {%if sumRow == "Quantity" %} quantity + | {%elif sumRow == "NetAmount" %} netamount + | {%else %} {% raise "unknown sum: " + sumRow %} + | {%endif %} + | ) + | {% if moreColumns == true %} , MAX(latitude) {% endif %} + |FROM lokad_orders + |GROUP BY + | {%if key == "year" %} YEAR(date) + | {%elif key == "month" %} MONTH(date) + | {%elif key == "Quantity" %} quantity + | {%elif key == "Currency" %} currency + | {%else %} {% raise "unknown key:" + key %} + | {%endif %} AS {{key}} + |""".stripMargin) { t => + val q = Q(t.q) + q should validate + } + + test("scopes") { _ => + val q = Q("SELECT {{ environment.scopes|length }} AS n") + val GetProgramDescriptionSuccess(description) = q.description() + assert(description.maybeRunnable.get.params.exists(_.isEmpty)) + q should give("""[{"n":0}]""") + } + + test("secret") { _ 
=> + val q = Q("""SELECT {{ environment.secret("blah")}} AS n""") + val GetProgramDescriptionSuccess(description) = q.description() + assert(description.maybeRunnable.get.params.exists(_.isEmpty)) + q should give("""[{"n":null}]""") + } + + test("SELECT airport_id, {{ c }} FROM {{ }} ") { q => + val v = compilerService.validate(q.q, asJson()) + assert(v != null) + } + + private var compilerService: CompilerService = _ + + private val user = InteractiveUser(Uid(database), "fdw user", "email", Seq.empty) + + private def asJson(params: Map[String, RawValue] = Map.empty): ProgramEnvironment = { + if (params.isEmpty) ProgramEnvironment(user, None, Set.empty, Map("output-format" -> "json")) + else ProgramEnvironment(user, Some(params.toArray), Set.empty, Map("output-format" -> "json")) + } + + override def beforeAll(): Unit = { + super.beforeAll() + compilerService = new JinjaSqlCompilerService(None) + } + + override def afterAll(): Unit = { + if (compilerService != null) { + compilerService.stop() + compilerService = null + } + } + +} diff --git a/project/Dependencies.scala b/project/Dependencies.scala index 004fb5e81..97e6f4d20 100644 --- a/project/Dependencies.scala +++ b/project/Dependencies.scala @@ -47,6 +47,7 @@ object Dependencies { // from client val trufflePolyglot = "org.graalvm.polyglot" % "polyglot" % "23.1.0" + val trufflePython = "org.graalvm.polyglot" % "python-community" % "23.1.0" // from snapi-parser val antlr4Runtime = "org.antlr" % "antlr4-runtime" % "4.12.0" @@ -79,7 +80,6 @@ object Dependencies { val jwtCore = "com.github.jwt-scala" %% "jwt-core" % "9.4.4-rawlabs" val springCore = "org.springframework" % "spring-core" % "5.3.13" val truffleCompiler = Seq( - "org.graalvm.truffle" % "truffle-api" % "23.1.0", "org.graalvm.truffle" % "truffle-api" % "23.1.0", "org.graalvm.truffle" % "truffle-compiler" % "23.1.0", "org.graalvm.truffle" % "truffle-nfi" % "23.1.0", diff --git a/python-client/src/main/scala/raw/client/python/PythonCompilerService.scala 
b/python-client/src/main/scala/raw/client/python/PythonCompilerService.scala index 98eab8c62..f0ffda2fd 100644 --- a/python-client/src/main/scala/raw/client/python/PythonCompilerService.scala +++ b/python-client/src/main/scala/raw/client/python/PythonCompilerService.scala @@ -12,42 +12,9 @@ package raw.client.python -import org.graalvm.polyglot.{Context, Engine, PolyglotAccess, PolyglotException, Source, Value} -import raw.client.api.{ - AutoCompleteResponse, - CompilerService, - CompilerServiceException, - EvalResponse, - EvalRuntimeFailure, - EvalSuccess, - ExecutionResponse, - ExecutionRuntimeFailure, - ExecutionSuccess, - FormatCodeResponse, - GetProgramDescriptionResponse, - GoToDefinitionResponse, - HoverResponse, - Pos, - ProgramEnvironment, - RawBool, - RawByte, - RawDate, - RawDecimal, - RawDouble, - RawFloat, - RawInt, - RawInterval, - RawLong, - RawNull, - RawShort, - RawString, - RawTime, - RawTimestamp, - RawType, - RawValue, - RenameResponse, - ValidateResponse -} +import org.graalvm.polyglot.io.IOAccess +import org.graalvm.polyglot.{Context, Engine, HostAccess, PolyglotAccess, PolyglotException, Source, Value} +import raw.client.api.{AutoCompleteResponse, CompilerService, CompilerServiceException, EvalResponse, EvalRuntimeFailure, EvalSuccess, ExecutionResponse, ExecutionRuntimeFailure, ExecutionSuccess, FormatCodeResponse, GetProgramDescriptionResponse, GoToDefinitionResponse, HoverResponse, Pos, ProgramEnvironment, RawBool, RawByte, RawDate, RawDecimal, RawDouble, RawFloat, RawInt, RawInterval, RawLong, RawNull, RawShort, RawString, RawTime, RawTimestamp, RawType, RawValue, RenameResponse, ValidateResponse} import raw.client.writers.{PolyglotBinaryWriter, PolyglotCsvWriter, PolyglotJsonWriter, PolyglotTextWriter} import raw.utils.{RawSettings, RawUtils} @@ -328,6 +295,9 @@ class PythonCompilerService(engineDefinition: (Engine, Boolean), maybeClassLoade .environment("RAW_SCOPES", environment.scopes.mkString(",")) .allowExperimentalOptions(true) 
.allowPolyglotAccess(PolyglotAccess.ALL) + .allowIO(IOAccess.ALL) + .allowHostAccess(HostAccess.ALL) + .allowNativeAccess(true) maybeOutputStream.foreach(os => ctxBuilder.out(os)) val ctx = ctxBuilder.build() ctx diff --git a/snapi-client/src/main/scala/raw/client/rql2/truffle/Rql2TruffleCompilerService.scala b/snapi-client/src/main/scala/raw/client/rql2/truffle/Rql2TruffleCompilerService.scala index 13adb38ea..0bc89f456 100644 --- a/snapi-client/src/main/scala/raw/client/rql2/truffle/Rql2TruffleCompilerService.scala +++ b/snapi-client/src/main/scala/raw/client/rql2/truffle/Rql2TruffleCompilerService.scala @@ -15,6 +15,7 @@ package raw.client.rql2.truffle import org.bitbucket.inkytonik.kiama.relation.LeaveAlone import org.bitbucket.inkytonik.kiama.util.{Position, Positions} import org.graalvm.polyglot._ +import org.graalvm.polyglot.io.IOAccess import raw.client.api._ import raw.client.rql2.api._ import raw.client.writers.{PolyglotBinaryWriter, PolyglotTextWriter} @@ -698,6 +699,9 @@ class Rql2TruffleCompilerService(engineDefinition: (Engine, Boolean), maybeClass .environment("RAW_SCOPES", environment.scopes.mkString(",")) .allowExperimentalOptions(true) .allowPolyglotAccess(PolyglotAccess.ALL) + .allowIO(IOAccess.ALL) + .allowHostAccess(HostAccess.ALL) + .allowNativeAccess(true) environment.options.get("staged-compiler").foreach { stagedCompiler => ctxBuilder.option("rql.staged-compiler", stagedCompiler) } diff --git a/snapi-frontend/src/main/scala/raw/creds/local/LocalCredentialsService.scala b/snapi-frontend/src/main/scala/raw/creds/local/LocalCredentialsService.scala index fee6e5add..36e8fd08f 100644 --- a/snapi-frontend/src/main/scala/raw/creds/local/LocalCredentialsService.scala +++ b/snapi-frontend/src/main/scala/raw/creds/local/LocalCredentialsService.scala @@ -146,7 +146,7 @@ class LocalCredentialsService extends CredentialsService { } override def getUserDb(user: AuthenticatedUser): String = { - "default-user-db" + "unittest" } override def doStop(): Unit 
= {}