Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Jinja python alex #438

Open
wants to merge 28 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ lazy val root = (project in file("."))
snapiTruffle,
snapiClient,
sqlParser,
sqlClient
sqlClient,
jinjaSqlClient
)
.settings(
commonSettings,
Expand Down Expand Up @@ -282,5 +283,18 @@ lazy val pythonClient = (project in file("python-client"))
missingInterpolatorCompileSettings,
testSettings,
Compile / packageBin / packageOptions += Package.ManifestAttributes("Automatic-Module-Name" -> "raw.python.client"),
libraryDependencies += "org.graalvm.polyglot" % "python" % "23.1.0" % Provided
libraryDependencies += trufflePython
)

lazy val jinjaSqlClient = (project in file("jinja-sql-client"))
.dependsOn(
client % "compile->compile;test->test",
sqlClient % "compile->compile;test->test",
)
.settings(
commonSettings,
missingInterpolatorCompileSettings,
testSettings,
libraryDependencies += trufflePython
)

6 changes: 5 additions & 1 deletion client/src/main/scala/raw/client/api/CompilerService.scala
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,11 @@ object CompilerService {
// options.put("engine.CompilationFailureAction", "Diagnose")
// options.put("compiler.LogInlinedTargets", "true")
}
val engine = Engine.newBuilder().allowExperimentalOptions(true).options(options).build()
val engine = Engine
.newBuilder()
.allowExperimentalOptions(true)
.options(options)
.build()
enginesCache.put(settings, engine)
(engine, true)
}
Expand Down
1 change: 1 addition & 0 deletions jinja-sql-client/project/build.properties
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
sbt.version = 1.9.9
23 changes: 23 additions & 0 deletions jinja-sql-client/src/main/java/module-info.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
/*
* Copyright 2024 RAW Labs S.A.
*
* Use of this software is governed by the Business Source License
* included in the file licenses/BSL.txt.
*
* As of the Change Date specified in that file, in accordance with
* the Business Source License, use of this software will be governed
* by the Apache License, Version 2.0, included in the file
* licenses/APL.txt.
*/

module raw.client.jinja.sql {
requires scala.library;
requires raw.client;
requires raw.utils;
requires raw.sources;
requires org.slf4j;
requires org.graalvm.polyglot;

provides raw.client.api.CompilerServiceBuilder with
raw.client.jinja.sql.JinjaSqlCompilerServiceBuilder;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
raw.client.jinja.sql.JinjaSqlCompilerServiceBuilder
143 changes: 143 additions & 0 deletions jinja-sql-client/src/main/resources/python/raw_jinja.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
from jinja2 import Environment, meta, nodes
from jinja2.ext import Extension
from jinja2.exceptions import TemplateRuntimeError
from markupsafe import Markup

class RaiseExtension(Extension):
# This is our keyword(s):
tags = set(['raise'])

# See also: jinja2.parser.parse_include()
def parse(self, parser):
# the first token is the token that started the tag. In our case we
# only listen to "raise" so this will be a name token with
# "raise" as value. We get the line number so that we can give
# that line number to the nodes we insert.
lineno = next(parser.stream).lineno

# Extract the message from the template
message_node = parser.parse_expression()

return nodes.CallBlock(
self.call_method('_raise', [message_node], lineno=lineno),
[], [], [], lineno=lineno
)

def _raise(self, msg, caller):
raise TemplateRuntimeError(msg)


class RawJinjaException(Exception):

def __init__(self, message):
self._message = message

def message(self):
return self._message

class RawEnvironment:

def __init__(self, scopes, getSecret):
self.scopes = scopes
self._getSecret = getSecret

def secret(self,s):
return self._getSecret(s)


from datetime import datetime, time

class RawDate(datetime):
pass

class RawTimestamp(datetime):
pass

class RawTime(time):
pass


class RawJinja:

def __init__(self):
self._env = Environment(finalize=lambda x: self.fix(x), autoescape=False, extensions=[RaiseExtension])
self._env.filters['safe'] = self.flag_as_safe
self._env.filters['identifier'] = self.flag_as_identifier
# a default env to make sure 'environment' is predefined
self._env.globals['environment'] = RawEnvironment(None, None)

def fix(self, val):
if isinstance(val, Markup):
return val
elif isinstance(val, str):
return "'" + val.replace("'", "''") + "'"
elif isinstance(val, RawDate):
return "make_date(%d,%d,%d)" % (val.year, val.month, val.day)
elif isinstance(val, RawTime):
return "make_time(%d,%d,%d + %d / 1000000.0)" % (val.hour, val.minute, val.second, val.microsecond)
elif isinstance(val, RawTimestamp):
return "make_timestamp(%d,%d,%d,%d,%d,%f)" % (val.year, val.month, val.day, val.hour, val.minute, val.second + val.microsecond / 1000000.0)
elif isinstance(val, list):
items = ["(" + self.fix(i) + ")" for i in val]
return "(VALUES " + ",".join(items) + ")"
elif val is None or val == None:
return "NULL"
else:
return val

def flag_as_safe(self, s):
return Markup(s)

def flag_as_identifier(self, s):
return Markup('"' + s.replace('"', '""') + '"')

def _apply(self, code, args):
template = self._env.from_string(code)
return template.render(args)

def apply(self, code, scopes, secret, args):
adapter = RawJavaPythonAdapter()
d = {key: adapter._toPython(args.get(key)) for key in args.keySet()}
d['environment'] = RawEnvironment([s for s in scopes.iterator()], lambda s: secret.apply(s))
return self._apply(code, d)



def validate(self, code):
tree = self._env.parse(code)
return list(meta.find_undeclared_variables(tree))

def metadata_comments(self, code):
return [content for (line, tipe, content) in self._env.lex(code) if tipe == 'comment']


class RawJavaPythonAdapter:

def __init__(self):
import java
self.javaLocalDateClass = java.type('java.time.LocalDate')
self.javaLocalTimeClass = java.type('java.time.LocalTime')
self.javaLocalDateTimeClass = java.type('java.time.LocalDateTime')

def _toPython(self, arg):
if isinstance(arg, self.javaLocalDateClass):
return RawDate(arg.getYear(), arg.getMonthValue(), arg.getDayOfMonth())
if isinstance(arg, self.javaLocalTimeClass):
return RawTime(arg.getHour(), arg.getMinute(), arg.getSecond(), int(arg.getNano() / 1000))
if isinstance(arg, self.javaLocalDateTimeClass):
return RawTimestamp(arg.getYear(), arg.getMonthValue(), arg.getDayOfMonth(),
arg.getHour(), arg.getMinute(), arg.getSecond(), int(arg.getNano()/ 1000))
return arg


# rawJinja = RawJinja()
# code = """{% if False %}
# SELECT {{ 1 }} AS v
# {% else %}
# SELECT ** {{ 2 }} AS v
# {% endif %}"""
# result = rawJinja._env.parse(code)
# result2 = rawJinja._env.lex(code)
# result3 = rawJinja._env.preprocess(code)
# result3 = rawJinja._env.from_string(code).render()
# print(result)
21 changes: 21 additions & 0 deletions jinja-sql-client/src/main/resources/python/test_jinja.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import raw_jinja

# def test_basic():
# rawJinja = raw_jinja.RawJinja()
# result = rawJinja.test()
# print(result)
# assert result == """SELECT ** 2 AS v"""

import unittest

class TestJinja(unittest.TestCase):

def test_upper(self):
rawJinja = raw_jinja.RawJinja()
code = """{% if False %}
SELECT {{ 1 }} AS v
{% else %}
SELECT ** {{ 2 }} AS v
{% endif %}"""
result = rawJinja._apply(code, {})
unittest.TestCase.assertEqual(self, result, """SELECT ** 2 AS v""")
9 changes: 9 additions & 0 deletions jinja-sql-client/src/main/resources/reference.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
raw.client.jinja-sql {
graalpy {
executable = "/home/ld/python-venvs/graalpy23-venv/bin/python3"
home = "/home/ld/.pyenv/versions/graalpy-community-23.1.0"
core = "/home/ld/.pyenv/versions/graalpy-community-23.1.0/lib/graalpy23.1"
stdLib = "/home/ld/.pyenv/versions/graalpy-community-23.1.0/lib/python3.10"
logging = false
}
}
Loading