Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add max rows to ProgramEnvironment protocol + ExecutionSuccess response #462

Merged
merged 4 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions client/src/main/scala/raw/client/api/CompilerService.scala
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,8 @@ trait CompilerService extends RawService {
source: String,
environment: ProgramEnvironment,
maybeDecl: Option[String],
outputStream: OutputStream
outputStream: OutputStream,
maxRows: Option[Long] = None
): ExecutionResponse

// Format a source program.
Expand Down Expand Up @@ -300,7 +301,7 @@ final case class GetProgramDescriptionSuccess(programDescription: ProgramDescrip
extends GetProgramDescriptionResponse

sealed trait ExecutionResponse
case object ExecutionSuccess extends ExecutionResponse
final case class ExecutionSuccess(complete: Boolean) extends ExecutionResponse
final case class ExecutionValidationFailure(errors: List[ErrorMessage]) extends ExecutionResponse
final case class ExecutionRuntimeFailure(error: String) extends ExecutionResponse

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,8 @@ class PythonCompilerService(engineDefinition: (Engine, Boolean))(implicit protec
source: String,
environment: ProgramEnvironment,
maybeDecl: Option[String],
outputStream: OutputStream
outputStream: OutputStream,
maxRows: Option[Long]
): ExecutionResponse = {
val ctx = buildTruffleContext(environment, maybeOutputStream = Some(outputStream))
ctx.initialize("python")
Expand Down Expand Up @@ -139,7 +140,7 @@ class PythonCompilerService(engineDefinition: (Engine, Boolean))(implicit protec
try {
w.write(v)
w.flush()
ExecutionSuccess
ExecutionSuccess(complete = true)
} catch {
case ex: IOException => ExecutionRuntimeFailure(ex.getMessage)
} finally {
Expand All @@ -150,7 +151,7 @@ class PythonCompilerService(engineDefinition: (Engine, Boolean))(implicit protec
try {
w.write(v)
w.flush()
ExecutionSuccess
ExecutionSuccess(complete = true)
} catch {
case ex: IOException => ExecutionRuntimeFailure(ex.getMessage)
} finally {
Expand All @@ -160,15 +161,15 @@ class PythonCompilerService(engineDefinition: (Engine, Boolean))(implicit protec
val w = new PolyglotTextWriter(outputStream)
try {
w.writeAndFlush(v)
ExecutionSuccess
ExecutionSuccess(complete = true)
} catch {
case ex: IOException => ExecutionRuntimeFailure(ex.getMessage)
}
case Some("binary") =>
val w = new PolyglotBinaryWriter(outputStream)
try {
w.writeAndFlush(v)
ExecutionSuccess
ExecutionSuccess(complete = true)
} catch {
case ex: IOException => ExecutionRuntimeFailure(ex.getMessage)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,21 @@ class TestPythonCompilerService extends RawTestSuite with SettingsTestContext wi
test("basic execute test") { _ =>
val environment = ProgramEnvironment(user, None, Set.empty, Map("output-format" -> "json"))
val baos = new ByteArrayOutputStream()
assert(compilerService.execute("1+1", environment, None, baos) == ExecutionSuccess)
assert(compilerService.execute("1+1", environment, None, baos) == ExecutionSuccess(true))
assert(baos.toString() == "2")
}

test("basic execute test w/ decl") { _ =>
val environment = ProgramEnvironment(user, None, Set.empty, Map("output-format" -> "json"))
val baos = new ByteArrayOutputStream()
assert(compilerService.execute("def f(): return 1+1", environment, Some("f"), baos) == ExecutionSuccess)
assert(compilerService.execute("def f(): return 1+1", environment, Some("f"), baos) == ExecutionSuccess(true))
assert(baos.toString() == "2")
}

test("basic execute test w/ decl and arguments") { _ =>
val environment = ProgramEnvironment(user, Some(Array("v" -> RawInt(2))), Set.empty, Map("output-format" -> "json"))
val baos = new ByteArrayOutputStream()
assert(compilerService.execute("def f(v): return v*2", environment, Some("f"), baos) == ExecutionSuccess)
assert(compilerService.execute("def f(v): return v*2", environment, Some("f"), baos) == ExecutionSuccess(true))
assert(baos.toString() == "4")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ import java.util.Base64
import scala.annotation.tailrec
import scala.util.control.NonFatal

final class Rql2CsvWriter(os: OutputStream, lineSeparator: String) extends Closeable {
final class Rql2CsvWriter(os: OutputStream, lineSeparator: String, maxRows: Option[Long]) extends Closeable {

final private val gen =
try {
Expand All @@ -73,6 +73,10 @@ final class Rql2CsvWriter(os: OutputStream, lineSeparator: String) extends Close
final private val tryable = Rql2IsTryableTypeProperty()
final private val nullable = Rql2IsNullableTypeProperty()

private var maxRowsReached = false

def complete: Boolean = !maxRowsReached

@throws[IOException]
def write(v: Value, t: Rql2TypeWithProperties): Unit = {
if (t.props.contains(tryable)) {
Expand All @@ -97,9 +101,15 @@ final class Rql2CsvWriter(os: OutputStream, lineSeparator: String) extends Close
gen.setSchema(schemaBuilder.build)
gen.enable(STRICT_CHECK_FOR_QUOTING)
val iterator = v.getIterator
while (iterator.hasIteratorNextElement) {
val next = iterator.getIteratorNextElement
writeColumns(next, recordType)
var rowsWritten = 0L
while (iterator.hasIteratorNextElement && !maxRowsReached) {
if (maxRows.isDefined && rowsWritten >= maxRows.get) {
maxRowsReached = true
} else {
val next = iterator.getIteratorNextElement
writeColumns(next, recordType)
rowsWritten += 1
}
}
case Rql2ListType(recordType: Rql2RecordType, _) =>
val columnNames = recordType.atts.map(_.idn)
Expand All @@ -109,10 +119,12 @@ final class Rql2CsvWriter(os: OutputStream, lineSeparator: String) extends Close
gen.setSchema(schemaBuilder.build)
gen.enable(STRICT_CHECK_FOR_QUOTING)
val size = v.getArraySize
for (i <- 0L until size) {
for (i <- 0L until Math.min(size, maxRows.getOrElse(Long.MaxValue))) {
val next = v.getArrayElement(i)
writeColumns(next, recordType)
}
// Check if maxRows is reached.
maxRows.foreach(max => maxRowsReached = size > max)
case _ => throw new IOException("unsupported type")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import java.time.format.DateTimeFormatter
import java.util.Base64
import scala.util.control.NonFatal

final class Rql2JsonWriter(os: OutputStream) extends Closeable {
final class Rql2JsonWriter(os: OutputStream, maxRows: Option[Long]) extends Closeable {

final private val gen =
try {
Expand All @@ -40,108 +40,122 @@ final class Rql2JsonWriter(os: OutputStream) extends Closeable {
final private val tryable = Rql2IsTryableTypeProperty()
final private val nullable = Rql2IsNullableTypeProperty()

private var maxRowsReached = false

def complete: Boolean = !maxRowsReached

def write(v: Value, t: Rql2TypeWithProperties): Unit = {
if (t.props.contains(tryable)) {
if (v.isException) {
v.throwException()
} else {
writeValue(v, t.cloneAndRemoveProp(tryable).asInstanceOf[Rql2TypeWithProperties])
writeValue(v, t.cloneAndRemoveProp(tryable).asInstanceOf[Rql2TypeWithProperties], maxRows)
}
} else {
writeValue(v, t.cloneAndRemoveProp(tryable).asInstanceOf[Rql2TypeWithProperties])
writeValue(v, t.cloneAndRemoveProp(tryable).asInstanceOf[Rql2TypeWithProperties], maxRows)
}
}

@throws[IOException]
private def writeValue(v: Value, t: Rql2TypeWithProperties): Unit = {
private def writeValue(v: Value, t: Rql2TypeWithProperties, maxRows: Option[Long]): Unit = {
if (t.props.contains(tryable)) {
if (v.isException) {
try {
v.throwException()
} catch {
case NonFatal(ex) => gen.writeString(ex.getMessage)
}
} else writeValue(v, t.cloneAndRemoveProp(tryable).asInstanceOf[Rql2TypeWithProperties])
} else writeValue(v, t.cloneAndRemoveProp(tryable).asInstanceOf[Rql2TypeWithProperties], maxRows = maxRows)
} else if (t.props.contains(nullable)) {
if (v.isNull) gen.writeNull()
else writeValue(v, t.cloneAndRemoveProp(nullable).asInstanceOf[Rql2TypeWithProperties])
} else t match {
case _: Rql2BinaryType =>
val bytes = (0L until v.getBufferSize).map(v.readBufferByte)
gen.writeString(Base64.getEncoder.encodeToString(bytes.toArray))
case _: Rql2BoolType => gen.writeBoolean(v.asBoolean())
case _: Rql2ByteType => gen.writeNumber(v.asByte().toInt)
case _: Rql2ShortType => gen.writeNumber(v.asShort().toInt)
case _: Rql2IntType => gen.writeNumber(v.asInt())
case _: Rql2LongType => gen.writeNumber(v.asLong())
case _: Rql2FloatType => gen.writeNumber(v.asFloat())
case _: Rql2DoubleType => gen.writeNumber(v.asDouble())
case _: Rql2DecimalType => gen.writeNumber(v.asString())
case _: Rql2StringType => gen.writeString(v.asString())
case _: Rql2DateType =>
val date = v.asDate()
gen.writeString(dateFormatter.format(date))
case _: Rql2TimeType =>
val time = v.asTime()
val formatted = timeFormatter.format(time)
gen.writeString(formatted)
case _: Rql2TimestampType =>
val date = v.asDate()
val time = v.asTime()
val dateTime = date.atTime(time)
val formatted = timestampFormatter.format(dateTime)
gen.writeString(formatted)
case _: Rql2IntervalType =>
val duration = v.asDuration()
val days = duration.toDays
val hours = duration.toHoursPart
val minutes = duration.toMinutesPart
val seconds = duration.toSecondsPart
val s = new StringBuilder()
if (days > 0) s.append(s"$days days, ")
if (hours > 0) s.append(s"$hours hours, ")
if (minutes > 0) s.append(s"$minutes minutes, ")
s.append(s"$seconds seconds")
gen.writeString(s.toString())
case Rql2RecordType(atts, _) =>
gen.writeStartObject()
val keys = new java.util.Vector[String]
atts.foreach(a => keys.add(a.idn))
val distincted = RecordFieldsNaming.makeDistinct(keys)
for (i <- 0 until distincted.size()) {
val field = distincted.get(i)
gen.writeFieldName(field)
val a = v.getMember(field)
writeValue(a, atts(i).tipe.asInstanceOf[Rql2TypeWithProperties])
}
gen.writeEndObject()
case Rql2IterableType(innerType, _) =>
val iterator = v.getIterator
gen.writeStartArray()
while (iterator.hasIteratorNextElement) {
val next = iterator.getIteratorNextElement
writeValue(next, innerType.asInstanceOf[Rql2TypeWithProperties])
}
gen.writeEndArray()
case Rql2ListType(innerType, _) =>
val size = v.getArraySize
gen.writeStartArray()
for (i <- 0L until size) {
val next = v.getArrayElement(i)
writeValue(next, innerType.asInstanceOf[Rql2TypeWithProperties])
}
gen.writeEndArray()
case Rql2OrType(tipes, _) if tipes.exists(Rql2TypeUtils.getProps(_).nonEmpty) =>
// A trick to make sur inner types do not have properties
val inners = tipes.map { case inner: Rql2TypeWithProperties => Rql2TypeUtils.resetProps(inner, Set.empty) }
val orProps = tipes.flatMap { case inner: Rql2TypeWithProperties => inner.props }.toSet
writeValue(v, Rql2OrType(inners, orProps))
case Rql2OrType(tipes, _) =>
val index = v.invokeMember("getIndex").asInt()
val actualValue = v.invokeMember("getValue")
writeValue(actualValue, tipes(index).asInstanceOf[Rql2TypeWithProperties])
else writeValue(v, t.cloneAndRemoveProp(nullable).asInstanceOf[Rql2TypeWithProperties], maxRows = maxRows)
} else {
t match {
case _: Rql2BinaryType =>
val bytes = (0L until v.getBufferSize).map(v.readBufferByte)
gen.writeString(Base64.getEncoder.encodeToString(bytes.toArray))
case _: Rql2BoolType => gen.writeBoolean(v.asBoolean())
case _: Rql2ByteType => gen.writeNumber(v.asByte().toInt)
case _: Rql2ShortType => gen.writeNumber(v.asShort().toInt)
case _: Rql2IntType => gen.writeNumber(v.asInt())
case _: Rql2LongType => gen.writeNumber(v.asLong())
case _: Rql2FloatType => gen.writeNumber(v.asFloat())
case _: Rql2DoubleType => gen.writeNumber(v.asDouble())
case _: Rql2DecimalType => gen.writeNumber(v.asString())
case _: Rql2StringType => gen.writeString(v.asString())
case _: Rql2DateType =>
val date = v.asDate()
gen.writeString(dateFormatter.format(date))
case _: Rql2TimeType =>
val time = v.asTime()
val formatted = timeFormatter.format(time)
gen.writeString(formatted)
case _: Rql2TimestampType =>
val date = v.asDate()
val time = v.asTime()
val dateTime = date.atTime(time)
val formatted = timestampFormatter.format(dateTime)
gen.writeString(formatted)
case _: Rql2IntervalType =>
val duration = v.asDuration()
val days = duration.toDays
val hours = duration.toHoursPart
val minutes = duration.toMinutesPart
val seconds = duration.toSecondsPart
val s = new StringBuilder()
if (days > 0) s.append(s"$days days, ")
if (hours > 0) s.append(s"$hours hours, ")
if (minutes > 0) s.append(s"$minutes minutes, ")
s.append(s"$seconds seconds")
gen.writeString(s.toString())
case Rql2RecordType(atts, _) =>
gen.writeStartObject()
val keys = new java.util.Vector[String]
atts.foreach(a => keys.add(a.idn))
val distincted = RecordFieldsNaming.makeDistinct(keys)
for (i <- 0 until distincted.size()) {
val field = distincted.get(i)
gen.writeFieldName(field)
val a = v.getMember(field)
writeValue(a, atts(i).tipe.asInstanceOf[Rql2TypeWithProperties], maxRows = None)
}
gen.writeEndObject()
case Rql2IterableType(innerType, _) =>
var rowsWritten = 0L
val iterator = v.getIterator
gen.writeStartArray()
while (iterator.hasIteratorNextElement && !maxRowsReached) {
if (maxRows.isDefined && rowsWritten >= maxRows.get) {
maxRowsReached = true
} else {
val next = iterator.getIteratorNextElement
writeValue(next, innerType.asInstanceOf[Rql2TypeWithProperties], maxRows = None)
rowsWritten += 1
}
}
gen.writeEndArray()
case Rql2ListType(innerType, _) =>
val size = v.getArraySize
gen.writeStartArray()
for (i <- 0L until Math.min(size, maxRows.getOrElse(Long.MaxValue))) {
val next = v.getArrayElement(i)
writeValue(next, innerType.asInstanceOf[Rql2TypeWithProperties], maxRows = None)
}
gen.writeEndArray()
// Check if maxRows is reached.
maxRows.foreach(max => maxRowsReached = size > max)
case Rql2OrType(tipes, _) if tipes.exists(Rql2TypeUtils.getProps(_).nonEmpty) =>
// A trick to make sur inner types do not have properties
val inners = tipes.map { case inner: Rql2TypeWithProperties => Rql2TypeUtils.resetProps(inner, Set.empty) }
val orProps = tipes.flatMap { case inner: Rql2TypeWithProperties => inner.props }.toSet
writeValue(v, Rql2OrType(inners, orProps), maxRows = None)
case Rql2OrType(tipes, _) =>
val index = v.invokeMember("getIndex").asInt()
val actualValue = v.invokeMember("getValue")
writeValue(actualValue, tipes(index).asInstanceOf[Rql2TypeWithProperties], maxRows = None)

case _ => throw new RuntimeException("unsupported type")
case _ => throw new RuntimeException("unsupported type")
}
}
}

Expand Down
Loading