Skip to content

Commit

Permalink
Using Jinjava
Browse files Browse the repository at this point in the history
  • Loading branch information
bgaidioz committed May 10, 2024
1 parent ee2cbdc commit c4d937b
Show file tree
Hide file tree
Showing 10 changed files with 515 additions and 4 deletions.
17 changes: 16 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ lazy val root = (project in file("."))
snapiTruffle,
snapiClient,
sqlParser,
sqlClient
sqlClient,
jinjaSqlClient
)
.settings(
commonSettings,
Expand Down Expand Up @@ -380,3 +381,17 @@ lazy val pythonClient = (project in file("python-client"))
Compile / packageBin / packageOptions += Package.ManifestAttributes("Automatic-Module-Name" -> "raw.python.client"),
libraryDependencies += "org.graalvm.polyglot" % "python" % "23.1.0" % Provided
)

lazy val jinjaSqlClient = (project in file("jinja-sql-client"))
.dependsOn(
client % "compile->compile;test->test",
sqlClient % "compile->compile;test->test",
)
.settings(
commonSettings,
missingInterpolatorCompileSettings,
testSettings,
libraryDependencies ++= Seq(
jinjava
)
)
8 changes: 6 additions & 2 deletions deps/others/build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,14 @@ val jacksonModuleScala = "com.fasterxml.jackson.module" %% "jackson-module-scala

val mysqlModule = "com.mysql" % "mysql-connector-j" % "8.1.0"

val jinjava = "com.hubspot.jinjava" % "jinjava" % "2.7.2" exclude("com.google.code.findbugs", "annotations")

libraryDependencies ++= Seq(
jwtCore,
scalaLogging,
jacksonModuleScala,
mysqlModule
mysqlModule,
jinjava
)

// Map of artifact ID to module name
Expand All @@ -26,6 +29,7 @@ val moduleNames = Map(
"scala-logging" -> "typesafe.scalalogging",
"jackson-module-scala" -> "com.fasterxml.jackson.scala",
"mysql-connector-j" -> "mysql.connector.j",
"jinjava" -> "jinjava"
)

def updatePom(pomFile: File, newVersion: String): Unit = {
Expand Down Expand Up @@ -157,4 +161,4 @@ createS3SyncScript := {

// Notify that the task is completed
println(s"Bash script created: $scriptFile")
}
}
1 change: 1 addition & 0 deletions jinja-sql-client/project/build.properties
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
sbt.version = 1.9.9
25 changes: 25 additions & 0 deletions jinja-sql-client/src/main/java/module-info.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
/*
* Copyright 2024 RAW Labs S.A.
*
* Use of this software is governed by the Business Source License
* included in the file licenses/BSL.txt.
*
* As of the Change Date specified in that file, in accordance with
* the Business Source License, use of this software will be governed
* by the Apache License, Version 2.0, included in the file
* licenses/APL.txt.
*/

module raw.client.jinja.sql {
requires scala.library;
requires raw.client;
requires raw.utils;
requires jinjava;
requires org.slf4j;

opens raw.client.jinja.sql to
jinjava;

provides raw.client.api.CompilerServiceBuilder with
raw.client.jinja.sql.JinjaSqlCompilerServiceBuilder;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
raw.client.jinja.sql.JinjaSqlCompilerServiceBuilder
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
/*
* Copyright 2024 RAW Labs S.A.
*
* Use of this software is governed by the Business Source License
* included in the file licenses/BSL.txt.
*
* As of the Change Date specified in that file, in accordance with
* the Business Source License, use of this software will be governed
* by the Apache License, Version 2.0, included in the file
* licenses/APL.txt.
*/

package raw.client.jinja.sql

import com.hubspot.jinjava.interpret.TemplateError.{ErrorReason, ErrorType}
import com.hubspot.jinjava.interpret.{JinjavaInterpreter, TemplateError}
import com.hubspot.jinjava.lib.fn.ELFunctionDefinition
import com.hubspot.jinjava.lib.tag.Tag
import com.hubspot.jinjava.tree.TagNode
import com.hubspot.jinjava.{Jinjava, JinjavaConfig}
import raw.client.api._
import raw.utils.RawSettings

import java.util
import scala.collection.mutable

class JinjaSqlCompilerService(maybeClassLoader: Option[ClassLoader] = None)(
implicit protected val settings: RawSettings
) extends CompilerService {

private val sqlCompilerService = CompilerServiceProvider("sql", maybeClassLoader)

def dotAutoComplete(
source: String,
environment: raw.client.api.ProgramEnvironment,
position: raw.client.api.Pos
): raw.client.api.AutoCompleteResponse = AutoCompleteResponse(Array.empty)

private val validationConfigBuilder = JinjavaConfig
.newBuilder()
.withValidationMode(true)
.withMaxOutputSize(10000)
.withFailOnUnknownTokens(false)
.withNestedInterpretationEnabled(false)

private val validationConfig = validationConfigBuilder.build()

private val executionConfig = validationConfigBuilder.withFailOnUnknownTokens(true).build()

private val validateJinjava = new Jinjava(validationConfig)
private val executeJinjava = new Jinjava(executionConfig)

private object FailFunc {
def doFail(): String = {
"doFail"
}
def dontFail(): String = {
"dontFail"
}
}

private class RaiseTag(doRaise: Boolean) extends Tag {

override def interpret(tagNode: TagNode, interpreter: JinjavaInterpreter): String = {
if (doRaise) {
val o = interpreter.resolveELExpression(tagNode.getHelpers, tagNode.getLineNumber, tagNode.getStartPosition)
val error = new TemplateError(ErrorType.FATAL, ErrorReason.EXCEPTION, o.asInstanceOf[String], null, 0, 0, null)
interpreter.addError(error)
}
s"<$getName>"
}

override def getEndTagName: String = null

override def getName: String = "raise"
}

private class ParamTag() extends Tag {

override def interpret(tagNode: TagNode, interpreter: JinjavaInterpreter): String = {
s"-- <$getName>"
}

override def getEndTagName: String = null

override def getName: String = "param"
}

private class TypeTag() extends Tag {

override def interpret(tagNode: TagNode, interpreter: JinjavaInterpreter): String = {
s"-- <$getName>"
}

override def getEndTagName: String = null

override def getName: String = "type"
}

private class DefaultTag() extends Tag {

override def interpret(tagNode: TagNode, interpreter: JinjavaInterpreter): String = {
s"-- <$getName>"
}

override def getEndTagName: String = null

override def getName: String = "default"
}

private val doFail = new ELFunctionDefinition("raw", "fail", FailFunc.getClass, "doFail")
private val dontFail = new ELFunctionDefinition("raw", "fail", FailFunc.getClass, "dontFail")

validateJinjava.getGlobalContext.registerFunction(dontFail)
executeJinjava.getGlobalContext.registerFunction(doFail)

validateJinjava.getGlobalContext.registerTag(new RaiseTag(doRaise = false))
executeJinjava.getGlobalContext.registerTag(new RaiseTag(doRaise = true))
for (tag <- List(new ParamTag(), new DefaultTag(), new TypeTag())) {
validateJinjava.registerTag(tag)
executeJinjava.registerTag(tag)
}

def execute(
source: String,
environment: raw.client.api.ProgramEnvironment,
maybeDecl: Option[String],
outputStream: java.io.OutputStream
): raw.client.api.ExecutionResponse = {
logger.debug("execute")
val arguments = new util.HashMap[String, Object]()
environment.maybeArguments.foreach(_.foreach { case (k, v) => arguments.put(k, rawValueToString(v)) })
val result = executeJinjava.renderForResult(source, arguments)
if (result.getErrors.isEmpty) {
val processed = result.getOutput
sqlCompilerService.execute(processed, environment, None, outputStream)
} else {
ExecutionValidationFailure(asMessages(result.getErrors))
}
}

private def rawValueToString(value: RawValue) = value match {
case RawString(v) => v
case _ => ???
}

def eval(
source: String,
tipe: raw.client.api.RawType,
environment: raw.client.api.ProgramEnvironment
): raw.client.api.EvalResponse = ???

def aiValidate(source: String, environment: raw.client.api.ProgramEnvironment): raw.client.api.ValidateResponse = ???

def formatCode(
source: String,
environment: raw.client.api.ProgramEnvironment,
maybeIndent: Option[Int],
maybeWidth: Option[Int]
): raw.client.api.FormatCodeResponse = FormatCodeResponse(None)

def getProgramDescription(
source: String,
environment: raw.client.api.ProgramEnvironment
): raw.client.api.GetProgramDescriptionResponse = {
val unknownVariables = mutable.Set.empty[String]
validateJinjava.getGlobalContext.setDynamicVariableResolver(v => {
unknownVariables.add(v); ""
})
val result = validateJinjava.renderForResult(source, new java.util.HashMap)
if (result.hasErrors) {
GetProgramDescriptionFailure(asMessages(result.getErrors))
} else {
GetProgramDescriptionSuccess(
ProgramDescription(
Map.empty,
Some(
DeclDescription(
Some(
unknownVariables.map(ParamDescription(_, Some(RawStringType(false, false)), None, None, true)).toVector
),
Some(RawIterableType(RawAnyType(), false, false)),
None
)
),
None
)
)
}
}

def goToDefinition(
source: String,
environment: raw.client.api.ProgramEnvironment,
position: raw.client.api.Pos
): raw.client.api.GoToDefinitionResponse = GoToDefinitionResponse(None)

def hover(
source: String,
environment: raw.client.api.ProgramEnvironment,
position: raw.client.api.Pos
): raw.client.api.HoverResponse = HoverResponse(None)

def rename(
source: String,
environment: raw.client.api.ProgramEnvironment,
position: raw.client.api.Pos
): raw.client.api.RenameResponse = RenameResponse(Array.empty)

def validate(source: String, environment: raw.client.api.ProgramEnvironment): ValidateResponse = {
val result = validateJinjava.renderForResult(source, new java.util.HashMap())
val errors = asMessages(result.getErrors)
ValidateResponse(errors)
}

private def asMessages(errors: util.List[TemplateError]): List[ErrorMessage] = {
val errorMessages = mutable.ArrayBuffer.empty[ErrorMessage]
errors.forEach { error =>
val message = error.getMessage
val line = error.getLineno
val column = error.getStartPosition
val range = ErrorRange(ErrorPosition(line, column), ErrorPosition(line, column + 1))
val errorMessage = ErrorMessage(message, List(range), "")
errorMessages += errorMessage
}
errorMessages.toList

}

def wordAutoComplete(
source: String,
environment: raw.client.api.ProgramEnvironment,
prefix: String,
position: raw.client.api.Pos
): raw.client.api.AutoCompleteResponse = AutoCompleteResponse(Array.empty)

// Members declared in raw.utils.RawService

def doStop(): Unit = {
sqlCompilerService.stop()
}

def build(maybeClassLoader: Option[ClassLoader])(
implicit settings: raw.utils.RawSettings
): raw.client.api.CompilerService = ???

def language: Set[String] = Set("jinja-sql")

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright 2024 RAW Labs S.A.
*
* Use of this software is governed by the Business Source License
* included in the file licenses/BSL.txt.
*
* As of the Change Date specified in that file, in accordance with
* the Business Source License, use of this software will be governed
* by the Apache License, Version 2.0, included in the file
* licenses/APL.txt.
*/

package raw.client.jinja.sql

import raw.client.api.{CompilerService, CompilerServiceBuilder}
import raw.utils.RawSettings

class JinjaSqlCompilerServiceBuilder extends CompilerServiceBuilder {
override def language: Set[String] = Set("jinja-sql")

override def build(maybeClassLoader: Option[ClassLoader])(implicit settings: RawSettings): CompilerService =
new JinjaSqlCompilerService(maybeClassLoader)

}
Loading

0 comments on commit c4d937b

Please sign in to comment.