Skip to content

Commit

Permalink
Make API key scopes visible to SQL (#440)
Browse files Browse the repository at this point in the history
  • Loading branch information
bgaidioz authored Jun 12, 2024
1 parent ec7de72 commit a57eb5c
Show file tree
Hide file tree
Showing 4 changed files with 98 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,39 @@ private case class ParamLocation(
/* The parameter `parsedTree` implies parsing errors have been potential caught and reported
upfront, but we can't assume that tree is error-free. Indeed, for
*/
class NamedParametersPreparedStatement(conn: Connection, parsedTree: ParseProgramResult) extends StrictLogging {
class NamedParametersPreparedStatement(
conn: Connection,
parsedTree: ParseProgramResult,
scopes: Set[String] = Set.empty
) extends StrictLogging {

{
// create the `scopes` table if it doesn't exist. And delete its rows (in case it existed already).
val stmt = conn.prepareStatement("""
|CREATE TEMPORARY TABLE IF NOT EXISTS scopes (token VARCHAR NOT NULL);
|DELETE FROM scopes;
|""".stripMargin)
// an error is reported as a CompilerServiceException
try {
stmt.execute()
} finally {
stmt.close()
}
// insert the query scopes
val insert = conn.prepareStatement("INSERT INTO scopes (token) VALUES (?)")
// an error is reported as a CompilerServiceException
try {
// all scopes are added as batches
for (scope <- scopes) {
insert.setString(1, scope)
insert.addBatch()
}
// execute once all inserts
insert.executeBatch()
} finally {
insert.close()
}
}

private val treePositions = parsedTree.positions

Expand Down Expand Up @@ -379,8 +411,6 @@ class NamedParametersPreparedStatement(conn: Connection, parsedTree: ParseProgra
def errorPosition(p: Position): ErrorPosition = ErrorPosition(p.line, p.column)
ErrorRange(errorPosition(position), errorPosition(position1))
}
def executeQuery(): ResultSet = stmt.executeQuery()

def executeWith(parameters: Seq[(String, RawValue)]): Either[String, ResultSet] = {
val mandatoryParameters = {
for (
Expand Down Expand Up @@ -528,7 +558,7 @@ class NamedParametersPreparedStatement(conn: Connection, parsedTree: ParseProgra
}

private def asErrorString(ex: PSQLException): Option[String] = {
if (Set("42", "22").exists(ex.getSQLState.startsWith)) {
if (Set("42", "22", "0A").exists(ex.getSQLState.startsWith)) {
// syntax error / semantic error
val psqlError = Option(ex.getServerErrorMessage) // getServerErrorMessage can be null!
val error = psqlError.map(_.getMessage).getOrElse(ex.getMessage)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,10 @@ class SqlCompilerService(maybeClassLoader: Option[ClassLoader] = None)(implicit
syntaxAnalyzer.parse(prog)
}

private def treeErrors(tree: ParseProgramResult, messages: Seq[String]): Seq[ErrorMessage] = {
val start = tree.positions.getStart(tree).get
private def treeErrors(program: ParseProgramResult, messages: Seq[String]): Seq[ErrorMessage] = {
val start = program.positions.getStart(program.tree).get
val startPosition = ErrorPosition(start.line, start.column)
val end = tree.positions.getFinish(tree).get
val end = program.positions.getFinish(program.tree).get
val endPosition = ErrorPosition(end.line, end.column)
messages.map(message => ErrorMessage(message, List(ErrorRange(startPosition, endPosition)), ErrorCode.SqlErrorCode))
}
Expand Down Expand Up @@ -159,7 +159,7 @@ class SqlCompilerService(maybeClassLoader: Option[ClassLoader] = None)(implicit
case Right(parsedTree) =>
val conn = connectionPool.getConnection(environment.user)
try {
val pstmt = new NamedParametersPreparedStatement(conn, parsedTree)
val pstmt = new NamedParametersPreparedStatement(conn, parsedTree, environment.scopes)
try {
pstmt.queryMetadata match {
case Right(info) => pgRowTypeToIterableType(info.outputType) match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,19 @@ class TestNamedParametersStatement
// Username equals the database
private val user = InteractiveUser(Uid(database), "fdw user", "email", Seq.empty)

private var connectionPool: SqlConnectionPool = _
private var con: java.sql.Connection = _

override def beforeAll(): Unit = {
super.beforeAll()
if (password != "") {
val connectionPool = new SqlConnectionPool(credentials)
connectionPool = new SqlConnectionPool(credentials)
con = connectionPool.getConnection(user)
}
super.beforeAll()
}

override def afterAll(): Unit = {
if (con != null) con.close()
if (connectionPool != null) connectionPool.stop()
super.afterAll()
}

Expand All @@ -61,9 +62,7 @@ class TestNamedParametersStatement
val code = "SELECT :v1 as arg"

val statement = new NamedParametersPreparedStatement(con, parse(code))
statement.setParam("v1", RawString("Hello!"))
val rs = statement.executeQuery()

val rs = statement.executeWith(Seq("v1" -> RawString("Hello!"))).right.get
rs.next()
assert(rs.getString("arg") == "Hello!")
}
Expand All @@ -73,8 +72,7 @@ class TestNamedParametersStatement

val code = "SELECT :v::varchar AS greeting;"
val statement = new NamedParametersPreparedStatement(con, parse(code))
statement.setParam("v", RawString("Hello!"))
val rs = statement.executeQuery()
val rs = statement.executeWith(Seq("v" -> RawString("Hello!"))).right.get

rs.next()
assert(rs.getString("greeting") == "Hello!")
Expand All @@ -89,9 +87,7 @@ class TestNamedParametersStatement
val metadata = statement.queryMetadata.right.get
assert(metadata.parameters.keys == Set("v1", "v2"))

statement.setParam("v1", RawString("Lisbon"))
statement.setParam("v2", RawInt(1))
val rs = statement.executeQuery()
val rs = statement.executeWith(Seq("v1" -> RawString("Lisbon"), "v2" -> RawInt(1))).right.get
rs.next()
assert(rs.getString(1) == "Lisbon")
assert(rs.getInt(2) == 1)
Expand All @@ -106,8 +102,7 @@ class TestNamedParametersStatement
|*/
|SELECT :v1 as arg -- neither this one :bar """.stripMargin
val statement = new NamedParametersPreparedStatement(con, parse(code))
statement.setParam("v1", RawString("Hello!"))
val rs = statement.executeQuery()
val rs = statement.executeWith(Seq("v1" -> RawString("Hello!"))).right.get

rs.next()
assert(rs.getString("arg") == "Hello!")
Expand All @@ -120,8 +115,7 @@ class TestNamedParametersStatement
val statement = new NamedParametersPreparedStatement(con, parse(code))
val metadata = statement.queryMetadata.right.get
assert(metadata.parameters.keys == Set("bar"))
statement.setParam("bar", RawString("Hello!"))
val rs = statement.executeQuery()
val rs = statement.executeWith(Seq("bar" -> RawString("Hello!"))).right.get

rs.next()
assert(rs.getString("v1") == ":foo")
Expand All @@ -136,7 +130,7 @@ class TestNamedParametersStatement
val metadata = statement.queryMetadata
assert(metadata.isRight)
assert(metadata.right.get.parameters.isEmpty)
val rs = statement.executeQuery()
val rs = statement.executeWith(Seq.empty).right.get

rs.next()
assert(rs.getString("arg") == """[1, 2, "3", {"a": "Hello"}]""")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -998,4 +998,54 @@ class TestSqlCompilerServiceAirports
val ValidateResponse(errors) = compilerService.validate(q, asJson())
assert(errors.nonEmpty)
}

test("""scopes work""") { _ =>
assume(password != "")
val baos = new ByteArrayOutputStream()
def runWith(q: String, scopes: Set[String]): String = {
val env = ProgramEnvironment(
user,
None,
scopes,
Map("output-format" -> "json")
)
assert(compilerService.validate(q, env).messages.isEmpty)
val GetProgramDescriptionSuccess(_) = compilerService.getProgramDescription(q, env)
baos.reset()
assert(compilerService.execute(q, env, None, baos) == ExecutionSuccess)
baos.toString
}
// assert(runWith("SELECT e.airport_id FROM example.airports e", Set.empty) == """[]""")
assert(runWith("SELECT token\nFROM scopes", Set.empty) == """[]""")
assert(runWith("SELECT * FROM scopes value ORDER by value", Set.empty) == """[]""")
assert(runWith("SELECT * FROM scopes AS value ORDER by value", Set("ADMIN")) == """[{"token":"ADMIN"}]""")
assert(
runWith(
"SELECT token FROM scopes value ORDER by value",
Set("ADMIN", "SALES", "DEV")
) == """[{"token":"ADMIN"},{"token":"DEV"},{"token":"SALES"}]"""
)
assert(
runWith(
"""SELECT 'DEV' IN (SELECT * FROM scopes) AS isDev,
| 'ADMIN' IN (SELECT token FROM scopes) AS isAdmin""".stripMargin,
Set("ADMIN")
) == """[{"isdev":false,"isadmin":true}]"""
)
// demo CASE WHEN to hide a certain field
val q = """SELECT
| CASE WHEN 'DEV' IN (SELECT * FROM scopes) THEN trip_id END AS trip_id, -- "AS trip_id" to name it normally
| departure_date,
| arrival_date
|FROM example.trips
|WHERE reason = 'Holidays' AND departure_date = DATE '2016-02-27'""".stripMargin
assert(
runWith(q, Set("ADMIN"))
== """[{"trip_id":null,"departure_date":"2016-02-27","arrival_date":"2016-03-06"}]"""
)
assert(
runWith(q, Set("DEV"))
== """[{"trip_id":0,"departure_date":"2016-02-27","arrival_date":"2016-03-06"}]"""
)
}
}

0 comments on commit a57eb5c

Please sign in to comment.