Skip to content

Commit

Permalink
Add max rows to ProgramEnvironment protocol + ExecutionSuccess respon…
Browse files Browse the repository at this point in the history
…se (#462)
  • Loading branch information
miguelbranco80 authored Jul 3, 2024
1 parent 5f22150 commit d07e88b
Show file tree
Hide file tree
Showing 11 changed files with 197 additions and 146 deletions.
5 changes: 3 additions & 2 deletions client/src/main/scala/raw/client/api/CompilerService.scala
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,8 @@ trait CompilerService extends RawService {
source: String,
environment: ProgramEnvironment,
maybeDecl: Option[String],
outputStream: OutputStream
outputStream: OutputStream,
maxRows: Option[Long] = None
): ExecutionResponse

// Format a source program.
Expand Down Expand Up @@ -300,7 +301,7 @@ final case class GetProgramDescriptionSuccess(programDescription: ProgramDescrip
extends GetProgramDescriptionResponse

sealed trait ExecutionResponse
case object ExecutionSuccess extends ExecutionResponse
final case class ExecutionSuccess(complete: Boolean) extends ExecutionResponse
final case class ExecutionValidationFailure(errors: List[ErrorMessage]) extends ExecutionResponse
final case class ExecutionRuntimeFailure(error: String) extends ExecutionResponse

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ class PythonCompilerService(engineDefinition: (Engine, Boolean))(implicit protec
source: String,
environment: ProgramEnvironment,
maybeDecl: Option[String],
outputStream: OutputStream
outputStream: OutputStream,
maxRows: Option[Long]
): ExecutionResponse = {
val ctx = buildTruffleContext(environment, maybeOutputStream = Some(outputStream))
ctx.initialize("python")
Expand Down Expand Up @@ -139,7 +140,7 @@ class PythonCompilerService(engineDefinition: (Engine, Boolean))(implicit protec
try {
w.write(v)
w.flush()
ExecutionSuccess
ExecutionSuccess(complete = true)
} catch {
case ex: IOException => ExecutionRuntimeFailure(ex.getMessage)
} finally {
Expand All @@ -150,7 +151,7 @@ class PythonCompilerService(engineDefinition: (Engine, Boolean))(implicit protec
try {
w.write(v)
w.flush()
ExecutionSuccess
ExecutionSuccess(complete = true)
} catch {
case ex: IOException => ExecutionRuntimeFailure(ex.getMessage)
} finally {
Expand All @@ -160,15 +161,15 @@ class PythonCompilerService(engineDefinition: (Engine, Boolean))(implicit protec
val w = new PolyglotTextWriter(outputStream)
try {
w.writeAndFlush(v)
ExecutionSuccess
ExecutionSuccess(complete = true)
} catch {
case ex: IOException => ExecutionRuntimeFailure(ex.getMessage)
}
case Some("binary") =>
val w = new PolyglotBinaryWriter(outputStream)
try {
w.writeAndFlush(v)
ExecutionSuccess
ExecutionSuccess(complete = true)
} catch {
case ex: IOException => ExecutionRuntimeFailure(ex.getMessage)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,21 @@ class TestPythonCompilerService extends RawTestSuite with SettingsTestContext wi
test("basic execute test") { _ =>
val environment = ProgramEnvironment(user, None, Set.empty, Map("output-format" -> "json"))
val baos = new ByteArrayOutputStream()
assert(compilerService.execute("1+1", environment, None, baos) == ExecutionSuccess)
assert(compilerService.execute("1+1", environment, None, baos) == ExecutionSuccess(true))
assert(baos.toString() == "2")
}

test("basic execute test w/ decl") { _ =>
val environment = ProgramEnvironment(user, None, Set.empty, Map("output-format" -> "json"))
val baos = new ByteArrayOutputStream()
assert(compilerService.execute("def f(): return 1+1", environment, Some("f"), baos) == ExecutionSuccess)
assert(compilerService.execute("def f(): return 1+1", environment, Some("f"), baos) == ExecutionSuccess(true))
assert(baos.toString() == "2")
}

test("basic execute test w/ decl and arguments") { _ =>
val environment = ProgramEnvironment(user, Some(Array("v" -> RawInt(2))), Set.empty, Map("output-format" -> "json"))
val baos = new ByteArrayOutputStream()
assert(compilerService.execute("def f(v): return v*2", environment, Some("f"), baos) == ExecutionSuccess)
assert(compilerService.execute("def f(v): return v*2", environment, Some("f"), baos) == ExecutionSuccess(true))
assert(baos.toString() == "4")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ import java.util.Base64
import scala.annotation.tailrec
import scala.util.control.NonFatal

final class Rql2CsvWriter(os: OutputStream, lineSeparator: String) extends Closeable {
final class Rql2CsvWriter(os: OutputStream, lineSeparator: String, maxRows: Option[Long]) extends Closeable {

final private val gen =
try {
Expand All @@ -73,6 +73,10 @@ final class Rql2CsvWriter(os: OutputStream, lineSeparator: String) extends Close
final private val tryable = Rql2IsTryableTypeProperty()
final private val nullable = Rql2IsNullableTypeProperty()

private var maxRowsReached = false

def complete: Boolean = !maxRowsReached

@throws[IOException]
def write(v: Value, t: Rql2TypeWithProperties): Unit = {
if (t.props.contains(tryable)) {
Expand All @@ -97,9 +101,15 @@ final class Rql2CsvWriter(os: OutputStream, lineSeparator: String) extends Close
gen.setSchema(schemaBuilder.build)
gen.enable(STRICT_CHECK_FOR_QUOTING)
val iterator = v.getIterator
while (iterator.hasIteratorNextElement) {
val next = iterator.getIteratorNextElement
writeColumns(next, recordType)
var rowsWritten = 0L
while (iterator.hasIteratorNextElement && !maxRowsReached) {
if (maxRows.isDefined && rowsWritten >= maxRows.get) {
maxRowsReached = true
} else {
val next = iterator.getIteratorNextElement
writeColumns(next, recordType)
rowsWritten += 1
}
}
case Rql2ListType(recordType: Rql2RecordType, _) =>
val columnNames = recordType.atts.map(_.idn)
Expand All @@ -109,10 +119,12 @@ final class Rql2CsvWriter(os: OutputStream, lineSeparator: String) extends Close
gen.setSchema(schemaBuilder.build)
gen.enable(STRICT_CHECK_FOR_QUOTING)
val size = v.getArraySize
for (i <- 0L until size) {
for (i <- 0L until Math.min(size, maxRows.getOrElse(Long.MaxValue))) {
val next = v.getArrayElement(i)
writeColumns(next, recordType)
}
// Check if maxRows is reached.
maxRows.foreach(max => maxRowsReached = size > max)
case _ => throw new IOException("unsupported type")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import java.time.format.DateTimeFormatter
import java.util.Base64
import scala.util.control.NonFatal

final class Rql2JsonWriter(os: OutputStream) extends Closeable {
final class Rql2JsonWriter(os: OutputStream, maxRows: Option[Long]) extends Closeable {

final private val gen =
try {
Expand All @@ -40,108 +40,122 @@ final class Rql2JsonWriter(os: OutputStream) extends Closeable {
final private val tryable = Rql2IsTryableTypeProperty()
final private val nullable = Rql2IsNullableTypeProperty()

private var maxRowsReached = false

def complete: Boolean = !maxRowsReached

def write(v: Value, t: Rql2TypeWithProperties): Unit = {
if (t.props.contains(tryable)) {
if (v.isException) {
v.throwException()
} else {
writeValue(v, t.cloneAndRemoveProp(tryable).asInstanceOf[Rql2TypeWithProperties])
writeValue(v, t.cloneAndRemoveProp(tryable).asInstanceOf[Rql2TypeWithProperties], maxRows)
}
} else {
writeValue(v, t.cloneAndRemoveProp(tryable).asInstanceOf[Rql2TypeWithProperties])
writeValue(v, t.cloneAndRemoveProp(tryable).asInstanceOf[Rql2TypeWithProperties], maxRows)
}
}

@throws[IOException]
private def writeValue(v: Value, t: Rql2TypeWithProperties): Unit = {
private def writeValue(v: Value, t: Rql2TypeWithProperties, maxRows: Option[Long]): Unit = {
if (t.props.contains(tryable)) {
if (v.isException) {
try {
v.throwException()
} catch {
case NonFatal(ex) => gen.writeString(ex.getMessage)
}
} else writeValue(v, t.cloneAndRemoveProp(tryable).asInstanceOf[Rql2TypeWithProperties])
} else writeValue(v, t.cloneAndRemoveProp(tryable).asInstanceOf[Rql2TypeWithProperties], maxRows = maxRows)
} else if (t.props.contains(nullable)) {
if (v.isNull) gen.writeNull()
else writeValue(v, t.cloneAndRemoveProp(nullable).asInstanceOf[Rql2TypeWithProperties])
} else t match {
case _: Rql2BinaryType =>
val bytes = (0L until v.getBufferSize).map(v.readBufferByte)
gen.writeString(Base64.getEncoder.encodeToString(bytes.toArray))
case _: Rql2BoolType => gen.writeBoolean(v.asBoolean())
case _: Rql2ByteType => gen.writeNumber(v.asByte().toInt)
case _: Rql2ShortType => gen.writeNumber(v.asShort().toInt)
case _: Rql2IntType => gen.writeNumber(v.asInt())
case _: Rql2LongType => gen.writeNumber(v.asLong())
case _: Rql2FloatType => gen.writeNumber(v.asFloat())
case _: Rql2DoubleType => gen.writeNumber(v.asDouble())
case _: Rql2DecimalType => gen.writeNumber(v.asString())
case _: Rql2StringType => gen.writeString(v.asString())
case _: Rql2DateType =>
val date = v.asDate()
gen.writeString(dateFormatter.format(date))
case _: Rql2TimeType =>
val time = v.asTime()
val formatted = timeFormatter.format(time)
gen.writeString(formatted)
case _: Rql2TimestampType =>
val date = v.asDate()
val time = v.asTime()
val dateTime = date.atTime(time)
val formatted = timestampFormatter.format(dateTime)
gen.writeString(formatted)
case _: Rql2IntervalType =>
val duration = v.asDuration()
val days = duration.toDays
val hours = duration.toHoursPart
val minutes = duration.toMinutesPart
val seconds = duration.toSecondsPart
val s = new StringBuilder()
if (days > 0) s.append(s"$days days, ")
if (hours > 0) s.append(s"$hours hours, ")
if (minutes > 0) s.append(s"$minutes minutes, ")
s.append(s"$seconds seconds")
gen.writeString(s.toString())
case Rql2RecordType(atts, _) =>
gen.writeStartObject()
val keys = new java.util.Vector[String]
atts.foreach(a => keys.add(a.idn))
val distincted = RecordFieldsNaming.makeDistinct(keys)
for (i <- 0 until distincted.size()) {
val field = distincted.get(i)
gen.writeFieldName(field)
val a = v.getMember(field)
writeValue(a, atts(i).tipe.asInstanceOf[Rql2TypeWithProperties])
}
gen.writeEndObject()
case Rql2IterableType(innerType, _) =>
val iterator = v.getIterator
gen.writeStartArray()
while (iterator.hasIteratorNextElement) {
val next = iterator.getIteratorNextElement
writeValue(next, innerType.asInstanceOf[Rql2TypeWithProperties])
}
gen.writeEndArray()
case Rql2ListType(innerType, _) =>
val size = v.getArraySize
gen.writeStartArray()
for (i <- 0L until size) {
val next = v.getArrayElement(i)
writeValue(next, innerType.asInstanceOf[Rql2TypeWithProperties])
}
gen.writeEndArray()
case Rql2OrType(tipes, _) if tipes.exists(Rql2TypeUtils.getProps(_).nonEmpty) =>
// A trick to make sur inner types do not have properties
val inners = tipes.map { case inner: Rql2TypeWithProperties => Rql2TypeUtils.resetProps(inner, Set.empty) }
val orProps = tipes.flatMap { case inner: Rql2TypeWithProperties => inner.props }.toSet
writeValue(v, Rql2OrType(inners, orProps))
case Rql2OrType(tipes, _) =>
val index = v.invokeMember("getIndex").asInt()
val actualValue = v.invokeMember("getValue")
writeValue(actualValue, tipes(index).asInstanceOf[Rql2TypeWithProperties])
else writeValue(v, t.cloneAndRemoveProp(nullable).asInstanceOf[Rql2TypeWithProperties], maxRows = maxRows)
} else {
t match {
case _: Rql2BinaryType =>
val bytes = (0L until v.getBufferSize).map(v.readBufferByte)
gen.writeString(Base64.getEncoder.encodeToString(bytes.toArray))
case _: Rql2BoolType => gen.writeBoolean(v.asBoolean())
case _: Rql2ByteType => gen.writeNumber(v.asByte().toInt)
case _: Rql2ShortType => gen.writeNumber(v.asShort().toInt)
case _: Rql2IntType => gen.writeNumber(v.asInt())
case _: Rql2LongType => gen.writeNumber(v.asLong())
case _: Rql2FloatType => gen.writeNumber(v.asFloat())
case _: Rql2DoubleType => gen.writeNumber(v.asDouble())
case _: Rql2DecimalType => gen.writeNumber(v.asString())
case _: Rql2StringType => gen.writeString(v.asString())
case _: Rql2DateType =>
val date = v.asDate()
gen.writeString(dateFormatter.format(date))
case _: Rql2TimeType =>
val time = v.asTime()
val formatted = timeFormatter.format(time)
gen.writeString(formatted)
case _: Rql2TimestampType =>
val date = v.asDate()
val time = v.asTime()
val dateTime = date.atTime(time)
val formatted = timestampFormatter.format(dateTime)
gen.writeString(formatted)
case _: Rql2IntervalType =>
val duration = v.asDuration()
val days = duration.toDays
val hours = duration.toHoursPart
val minutes = duration.toMinutesPart
val seconds = duration.toSecondsPart
val s = new StringBuilder()
if (days > 0) s.append(s"$days days, ")
if (hours > 0) s.append(s"$hours hours, ")
if (minutes > 0) s.append(s"$minutes minutes, ")
s.append(s"$seconds seconds")
gen.writeString(s.toString())
case Rql2RecordType(atts, _) =>
gen.writeStartObject()
val keys = new java.util.Vector[String]
atts.foreach(a => keys.add(a.idn))
val distincted = RecordFieldsNaming.makeDistinct(keys)
for (i <- 0 until distincted.size()) {
val field = distincted.get(i)
gen.writeFieldName(field)
val a = v.getMember(field)
writeValue(a, atts(i).tipe.asInstanceOf[Rql2TypeWithProperties], maxRows = None)
}
gen.writeEndObject()
case Rql2IterableType(innerType, _) =>
var rowsWritten = 0L
val iterator = v.getIterator
gen.writeStartArray()
while (iterator.hasIteratorNextElement && !maxRowsReached) {
if (maxRows.isDefined && rowsWritten >= maxRows.get) {
maxRowsReached = true
} else {
val next = iterator.getIteratorNextElement
writeValue(next, innerType.asInstanceOf[Rql2TypeWithProperties], maxRows = None)
rowsWritten += 1
}
}
gen.writeEndArray()
case Rql2ListType(innerType, _) =>
val size = v.getArraySize
gen.writeStartArray()
for (i <- 0L until Math.min(size, maxRows.getOrElse(Long.MaxValue))) {
val next = v.getArrayElement(i)
writeValue(next, innerType.asInstanceOf[Rql2TypeWithProperties], maxRows = None)
}
gen.writeEndArray()
// Check if maxRows is reached.
maxRows.foreach(max => maxRowsReached = size > max)
case Rql2OrType(tipes, _) if tipes.exists(Rql2TypeUtils.getProps(_).nonEmpty) =>
// A trick to make sur inner types do not have properties
val inners = tipes.map { case inner: Rql2TypeWithProperties => Rql2TypeUtils.resetProps(inner, Set.empty) }
val orProps = tipes.flatMap { case inner: Rql2TypeWithProperties => inner.props }.toSet
writeValue(v, Rql2OrType(inners, orProps), maxRows = None)
case Rql2OrType(tipes, _) =>
val index = v.invokeMember("getIndex").asInt()
val actualValue = v.invokeMember("getValue")
writeValue(actualValue, tipes(index).asInstanceOf[Rql2TypeWithProperties], maxRows = None)

case _ => throw new RuntimeException("unsupported type")
case _ => throw new RuntimeException("unsupported type")
}
}
}

Expand Down
Loading

0 comments on commit d07e88b

Please sign in to comment.