diff --git a/client/src/main/scala/raw/client/api/CompilerService.scala b/client/src/main/scala/raw/client/api/CompilerService.scala index 21aedd3b0..cd505203c 100644 --- a/client/src/main/scala/raw/client/api/CompilerService.scala +++ b/client/src/main/scala/raw/client/api/CompilerService.scala @@ -241,7 +241,8 @@ trait CompilerService extends RawService { source: String, environment: ProgramEnvironment, maybeDecl: Option[String], - outputStream: OutputStream + outputStream: OutputStream, + maxRows: Option[Long] = None ): ExecutionResponse // Format a source program. @@ -300,7 +301,7 @@ final case class GetProgramDescriptionSuccess(programDescription: ProgramDescrip extends GetProgramDescriptionResponse sealed trait ExecutionResponse -case object ExecutionSuccess extends ExecutionResponse +final case class ExecutionSuccess(complete: Boolean) extends ExecutionResponse final case class ExecutionValidationFailure(errors: List[ErrorMessage]) extends ExecutionResponse final case class ExecutionRuntimeFailure(error: String) extends ExecutionResponse diff --git a/python-client/src/main/scala/raw/client/python/PythonCompilerService.scala b/python-client/src/main/scala/raw/client/python/PythonCompilerService.scala index c4f279da6..1dab8f167 100644 --- a/python-client/src/main/scala/raw/client/python/PythonCompilerService.scala +++ b/python-client/src/main/scala/raw/client/python/PythonCompilerService.scala @@ -104,7 +104,8 @@ class PythonCompilerService(engineDefinition: (Engine, Boolean))(implicit protec source: String, environment: ProgramEnvironment, maybeDecl: Option[String], - outputStream: OutputStream + outputStream: OutputStream, + maxRows: Option[Long] ): ExecutionResponse = { val ctx = buildTruffleContext(environment, maybeOutputStream = Some(outputStream)) ctx.initialize("python") @@ -139,7 +140,7 @@ class PythonCompilerService(engineDefinition: (Engine, Boolean))(implicit protec try { w.write(v) w.flush() - ExecutionSuccess + ExecutionSuccess(complete = true) } catch { case ex: IOException => ExecutionRuntimeFailure(ex.getMessage) } finally { @@ -150,7 +151,7 @@ class PythonCompilerService(engineDefinition: (Engine, Boolean))(implicit protec try { w.write(v) w.flush() - ExecutionSuccess + ExecutionSuccess(complete = true) } catch { case ex: IOException => ExecutionRuntimeFailure(ex.getMessage) } finally { @@ -160,7 +161,7 @@ class PythonCompilerService(engineDefinition: (Engine, Boolean))(implicit protec val w = new PolyglotTextWriter(outputStream) try { w.writeAndFlush(v) - ExecutionSuccess + ExecutionSuccess(complete = true) } catch { case ex: IOException => ExecutionRuntimeFailure(ex.getMessage) } @@ -168,7 +169,7 @@ class PythonCompilerService(engineDefinition: (Engine, Boolean))(implicit protec val w = new PolyglotBinaryWriter(outputStream) try { w.writeAndFlush(v) - ExecutionSuccess + ExecutionSuccess(complete = true) } catch { case ex: IOException => ExecutionRuntimeFailure(ex.getMessage) } diff --git a/python-client/src/test/scala/raw/client/python/TestPythonCompilerService.scala b/python-client/src/test/scala/raw/client/python/TestPythonCompilerService.scala index b778dad08..f38f8ecc0 100644 --- a/python-client/src/test/scala/raw/client/python/TestPythonCompilerService.scala +++ b/python-client/src/test/scala/raw/client/python/TestPythonCompilerService.scala @@ -40,21 +40,21 @@ class TestPythonCompilerService extends RawTestSuite with SettingsTestContext wi test("basic execute test") { _ => val environment = ProgramEnvironment(user, None, Set.empty, Map("output-format" -> "json")) val baos = new ByteArrayOutputStream() - assert(compilerService.execute("1+1", environment, None, baos) == ExecutionSuccess) + assert(compilerService.execute("1+1", environment, None, baos) == ExecutionSuccess(true)) assert(baos.toString() == "2") } test("basic execute test w/ decl") { _ => val environment = ProgramEnvironment(user, None, Set.empty, Map("output-format" -> "json")) val baos = new ByteArrayOutputStream() - assert(compilerService.execute("def f(): return 1+1", environment, Some("f"), baos) == ExecutionSuccess) + assert(compilerService.execute("def f(): return 1+1", environment, Some("f"), baos) == ExecutionSuccess(true)) assert(baos.toString() == "2") } test("basic execute test w/ decl and arguments") { _ => val environment = ProgramEnvironment(user, Some(Array("v" -> RawInt(2))), Set.empty, Map("output-format" -> "json")) val baos = new ByteArrayOutputStream() - assert(compilerService.execute("def f(v): return v*2", environment, Some("f"), baos) == ExecutionSuccess) + assert(compilerService.execute("def f(v): return v*2", environment, Some("f"), baos) == ExecutionSuccess(true)) assert(baos.toString() == "4") } diff --git a/snapi-client/src/main/scala/raw/client/rql2/truffle/Rql2CsvWriter.scala b/snapi-client/src/main/scala/raw/client/rql2/truffle/Rql2CsvWriter.scala index aa5e0d7fb..ccfd4c120 100644 --- a/snapi-client/src/main/scala/raw/client/rql2/truffle/Rql2CsvWriter.scala +++ b/snapi-client/src/main/scala/raw/client/rql2/truffle/Rql2CsvWriter.scala @@ -47,7 +47,7 @@ import java.util.Base64 import scala.annotation.tailrec import scala.util.control.NonFatal -final class Rql2CsvWriter(os: OutputStream, lineSeparator: String) extends Closeable { +final class Rql2CsvWriter(os: OutputStream, lineSeparator: String, maxRows: Option[Long]) extends Closeable { final private val gen = try { @@ -73,6 +73,10 @@ final class Rql2CsvWriter(os: OutputStream, lineSeparator: String) extends Close final private val tryable = Rql2IsTryableTypeProperty() final private val nullable = Rql2IsNullableTypeProperty() + private var maxRowsReached = false + + def complete: Boolean = !maxRowsReached + @throws[IOException] def write(v: Value, t: Rql2TypeWithProperties): Unit = { if (t.props.contains(tryable)) { @@ -97,9 +101,15 @@ final class Rql2CsvWriter(os: OutputStream, lineSeparator: String) extends Close gen.setSchema(schemaBuilder.build) gen.enable(STRICT_CHECK_FOR_QUOTING) val iterator = v.getIterator - while (iterator.hasIteratorNextElement) { - val next = iterator.getIteratorNextElement - writeColumns(next, recordType) + var rowsWritten = 0L + while (iterator.hasIteratorNextElement && !maxRowsReached) { + if (maxRows.isDefined && rowsWritten >= maxRows.get) { + maxRowsReached = true + } else { + val next = iterator.getIteratorNextElement + writeColumns(next, recordType) + rowsWritten += 1 + } } case Rql2ListType(recordType: Rql2RecordType, _) => val columnNames = recordType.atts.map(_.idn) @@ -109,10 +119,12 @@ final class Rql2CsvWriter(os: OutputStream, lineSeparator: String) extends Close gen.setSchema(schemaBuilder.build) gen.enable(STRICT_CHECK_FOR_QUOTING) val size = v.getArraySize - for (i <- 0L until size) { + for (i <- 0L until Math.min(size, maxRows.getOrElse(Long.MaxValue))) { val next = v.getArrayElement(i) writeColumns(next, recordType) } + // Check if maxRows is reached. + maxRows.foreach(max => maxRowsReached = size > max) case _ => throw new IOException("unsupported type") } } diff --git a/snapi-client/src/main/scala/raw/client/rql2/truffle/Rql2JsonWriter.scala b/snapi-client/src/main/scala/raw/client/rql2/truffle/Rql2JsonWriter.scala index f4601fb22..38651bf47 100644 --- a/snapi-client/src/main/scala/raw/client/rql2/truffle/Rql2JsonWriter.scala +++ b/snapi-client/src/main/scala/raw/client/rql2/truffle/Rql2JsonWriter.scala @@ -23,7 +23,7 @@ import java.time.format.DateTimeFormatter import java.util.Base64 import scala.util.control.NonFatal -final class Rql2JsonWriter(os: OutputStream) extends Closeable { +final class Rql2JsonWriter(os: OutputStream, maxRows: Option[Long]) extends Closeable { final private val gen = try { @@ -40,20 +40,24 @@ final class Rql2JsonWriter(os: OutputStream) extends Closeable { final private val tryable = Rql2IsTryableTypeProperty() final private val nullable = Rql2IsNullableTypeProperty() + private var maxRowsReached = false + + def complete: Boolean = !maxRowsReached + def write(v: Value, t: Rql2TypeWithProperties): Unit = { if (t.props.contains(tryable)) { if (v.isException) { v.throwException() } else { - writeValue(v, t.cloneAndRemoveProp(tryable).asInstanceOf[Rql2TypeWithProperties]) + writeValue(v, t.cloneAndRemoveProp(tryable).asInstanceOf[Rql2TypeWithProperties], maxRows) } } else { - writeValue(v, t.cloneAndRemoveProp(tryable).asInstanceOf[Rql2TypeWithProperties]) + writeValue(v, t.cloneAndRemoveProp(tryable).asInstanceOf[Rql2TypeWithProperties], maxRows) } } @throws[IOException] - private def writeValue(v: Value, t: Rql2TypeWithProperties): Unit = { + private def writeValue(v: Value, t: Rql2TypeWithProperties, maxRows: Option[Long]): Unit = { if (t.props.contains(tryable)) { if (v.isException) { try { @@ -61,87 +65,97 @@ final class Rql2JsonWriter(os: OutputStream) extends Closeable { } catch { case NonFatal(ex) => gen.writeString(ex.getMessage) } - } else writeValue(v, t.cloneAndRemoveProp(tryable).asInstanceOf[Rql2TypeWithProperties]) + } else writeValue(v, t.cloneAndRemoveProp(tryable).asInstanceOf[Rql2TypeWithProperties], maxRows = maxRows) } else if (t.props.contains(nullable)) { if (v.isNull) gen.writeNull() - else writeValue(v, t.cloneAndRemoveProp(nullable).asInstanceOf[Rql2TypeWithProperties]) - } else t match { - case _: Rql2BinaryType => - val bytes = (0L until v.getBufferSize).map(v.readBufferByte) - gen.writeString(Base64.getEncoder.encodeToString(bytes.toArray)) - case _: Rql2BoolType => gen.writeBoolean(v.asBoolean()) - case _: Rql2ByteType => gen.writeNumber(v.asByte().toInt) - case _: Rql2ShortType => gen.writeNumber(v.asShort().toInt) - case _: Rql2IntType => gen.writeNumber(v.asInt()) - case _: Rql2LongType => gen.writeNumber(v.asLong()) - case _: Rql2FloatType => gen.writeNumber(v.asFloat()) - case _: Rql2DoubleType => gen.writeNumber(v.asDouble()) - case _: Rql2DecimalType => gen.writeNumber(v.asString()) - case _: Rql2StringType => gen.writeString(v.asString()) - case _: Rql2DateType => - val date = v.asDate() - gen.writeString(dateFormatter.format(date)) - case _: Rql2TimeType => - val time = v.asTime() - val formatted = timeFormatter.format(time) - gen.writeString(formatted) - case _: Rql2TimestampType => - val date = v.asDate() - val time = v.asTime() - val dateTime = date.atTime(time) - val formatted = timestampFormatter.format(dateTime) - gen.writeString(formatted) - case _: Rql2IntervalType => - val duration = v.asDuration() - val days = duration.toDays - val hours = duration.toHoursPart - val minutes = duration.toMinutesPart - val seconds = duration.toSecondsPart - val s = new StringBuilder() - if (days > 0) s.append(s"$days days, ") - if (hours > 0) s.append(s"$hours hours, ") - if (minutes > 0) s.append(s"$minutes minutes, ") - s.append(s"$seconds seconds") - gen.writeString(s.toString()) - case Rql2RecordType(atts, _) => - gen.writeStartObject() - val keys = new java.util.Vector[String] - atts.foreach(a => keys.add(a.idn)) - val distincted = RecordFieldsNaming.makeDistinct(keys) - for (i <- 0 until distincted.size()) { - val field = distincted.get(i) - gen.writeFieldName(field) - val a = v.getMember(field) - writeValue(a, atts(i).tipe.asInstanceOf[Rql2TypeWithProperties]) - } - gen.writeEndObject() - case Rql2IterableType(innerType, _) => - val iterator = v.getIterator - gen.writeStartArray() - while (iterator.hasIteratorNextElement) { - val next = iterator.getIteratorNextElement - writeValue(next, innerType.asInstanceOf[Rql2TypeWithProperties]) - } - gen.writeEndArray() - case Rql2ListType(innerType, _) => - val size = v.getArraySize - gen.writeStartArray() - for (i <- 0L until size) { - val next = v.getArrayElement(i) - writeValue(next, innerType.asInstanceOf[Rql2TypeWithProperties]) - } - gen.writeEndArray() - case Rql2OrType(tipes, _) if tipes.exists(Rql2TypeUtils.getProps(_).nonEmpty) => - // A trick to make sur inner types do not have properties - val inners = tipes.map { case inner: Rql2TypeWithProperties => Rql2TypeUtils.resetProps(inner, Set.empty) } - val orProps = tipes.flatMap { case inner: Rql2TypeWithProperties => inner.props }.toSet - writeValue(v, Rql2OrType(inners, orProps)) - case Rql2OrType(tipes, _) => - val index = v.invokeMember("getIndex").asInt() - val actualValue = v.invokeMember("getValue") - writeValue(actualValue, tipes(index).asInstanceOf[Rql2TypeWithProperties]) + else writeValue(v, t.cloneAndRemoveProp(nullable).asInstanceOf[Rql2TypeWithProperties], maxRows = maxRows) + } else { + t match { + case _: Rql2BinaryType => + val bytes = (0L until v.getBufferSize).map(v.readBufferByte) + gen.writeString(Base64.getEncoder.encodeToString(bytes.toArray)) + case _: Rql2BoolType => gen.writeBoolean(v.asBoolean()) + case _: Rql2ByteType => gen.writeNumber(v.asByte().toInt) + case _: Rql2ShortType => gen.writeNumber(v.asShort().toInt) + case _: Rql2IntType => gen.writeNumber(v.asInt()) + case _: Rql2LongType => gen.writeNumber(v.asLong()) + case _: Rql2FloatType => gen.writeNumber(v.asFloat()) + case _: Rql2DoubleType => gen.writeNumber(v.asDouble()) + case _: Rql2DecimalType => gen.writeNumber(v.asString()) + case _: Rql2StringType => gen.writeString(v.asString()) + case _: Rql2DateType => + val date = v.asDate() + gen.writeString(dateFormatter.format(date)) + case _: Rql2TimeType => + val time = v.asTime() + val formatted = timeFormatter.format(time) + gen.writeString(formatted) + case _: Rql2TimestampType => + val date = v.asDate() + val time = v.asTime() + val dateTime = date.atTime(time) + val formatted = timestampFormatter.format(dateTime) + gen.writeString(formatted) + case _: Rql2IntervalType => + val duration = v.asDuration() + val days = duration.toDays + val hours = duration.toHoursPart + val minutes = duration.toMinutesPart + val seconds = duration.toSecondsPart + val s = new StringBuilder() + if (days > 0) s.append(s"$days days, ") + if (hours > 0) s.append(s"$hours hours, ") + if (minutes > 0) s.append(s"$minutes minutes, ") + s.append(s"$seconds seconds") + gen.writeString(s.toString()) + case Rql2RecordType(atts, _) => + gen.writeStartObject() + val keys = new java.util.Vector[String] + atts.foreach(a => keys.add(a.idn)) + val distincted = RecordFieldsNaming.makeDistinct(keys) + for (i <- 0 until distincted.size()) { + val field = distincted.get(i) + gen.writeFieldName(field) + val a = v.getMember(field) + writeValue(a, atts(i).tipe.asInstanceOf[Rql2TypeWithProperties], maxRows = None) + } + gen.writeEndObject() + case Rql2IterableType(innerType, _) => + var rowsWritten = 0L + val iterator = v.getIterator + gen.writeStartArray() + while (iterator.hasIteratorNextElement && !maxRowsReached) { + if (maxRows.isDefined && rowsWritten >= maxRows.get) { + maxRowsReached = true + } else { + val next = iterator.getIteratorNextElement + writeValue(next, innerType.asInstanceOf[Rql2TypeWithProperties], maxRows = None) + rowsWritten += 1 + } + } + gen.writeEndArray() + case Rql2ListType(innerType, _) => + val size = v.getArraySize + gen.writeStartArray() + for (i <- 0L until Math.min(size, maxRows.getOrElse(Long.MaxValue))) { + val next = v.getArrayElement(i) + writeValue(next, innerType.asInstanceOf[Rql2TypeWithProperties], maxRows = None) + } + gen.writeEndArray() + // Check if maxRows is reached. + maxRows.foreach(max => maxRowsReached = size > max) + case Rql2OrType(tipes, _) if tipes.exists(Rql2TypeUtils.getProps(_).nonEmpty) => + // A trick to make sur inner types do not have properties + val inners = tipes.map { case inner: Rql2TypeWithProperties => Rql2TypeUtils.resetProps(inner, Set.empty) } + val orProps = tipes.flatMap { case inner: Rql2TypeWithProperties => inner.props }.toSet + writeValue(v, Rql2OrType(inners, orProps), maxRows = None) + case Rql2OrType(tipes, _) => + val index = v.invokeMember("getIndex").asInt() + val actualValue = v.invokeMember("getValue") + writeValue(actualValue, tipes(index).asInstanceOf[Rql2TypeWithProperties], maxRows = None) - case _ => throw new RuntimeException("unsupported type") + case _ => throw new RuntimeException("unsupported type") + } } } diff --git a/snapi-client/src/main/scala/raw/client/rql2/truffle/Rql2TruffleCompilerService.scala b/snapi-client/src/main/scala/raw/client/rql2/truffle/Rql2TruffleCompilerService.scala index c660d5d00..bd14181cf 100644 --- a/snapi-client/src/main/scala/raw/client/rql2/truffle/Rql2TruffleCompilerService.scala +++ b/snapi-client/src/main/scala/raw/client/rql2/truffle/Rql2TruffleCompilerService.scala @@ -197,7 +197,8 @@ class Rql2TruffleCompilerService(engineDefinition: (Engine, Boolean))(implicit p source: String, environment: ProgramEnvironment, maybeDecl: Option[String], - outputStream: OutputStream + outputStream: OutputStream, + maxRows: Option[Long] ): ExecutionResponse = { val ctx = buildTruffleContext(environment, maybeOutputStream = Some(outputStream)) ctx.initialize("rql") @@ -279,11 +280,11 @@ class Rql2TruffleCompilerService(engineDefinition: (Engine, Boolean))(implicit p case _ => programContext.settings.config.getBoolean("raw.compiler.windows-line-ending") } val lineSeparator = if (windowsLineEnding) "\r\n" else "\n" - val w = new Rql2CsvWriter(outputStream, lineSeparator) + val w = new Rql2CsvWriter(outputStream, lineSeparator, maxRows) try { w.write(v, tipe.asInstanceOf[Rql2TypeWithProperties]) w.flush() - ExecutionSuccess + ExecutionSuccess(w.complete) } catch { case ex: IOException => ExecutionRuntimeFailure(ex.getMessage) } finally { @@ -293,11 +294,11 @@ class Rql2TruffleCompilerService(engineDefinition: (Engine, Boolean))(implicit p if (!JsonPackage.outputWriteSupport(tipe)) { return ExecutionRuntimeFailure("unsupported type") } - val w = new Rql2JsonWriter(outputStream) + val w = new Rql2JsonWriter(outputStream, maxRows) try { w.write(v, tipe.asInstanceOf[Rql2TypeWithProperties]) w.flush() - ExecutionSuccess + ExecutionSuccess(w.complete) } catch { case ex: IOException => ExecutionRuntimeFailure(ex.getMessage) } finally { @@ -310,7 +311,7 @@ class Rql2TruffleCompilerService(engineDefinition: (Engine, Boolean))(implicit p val w = new PolyglotTextWriter(outputStream) try { w.writeAndFlush(v) - ExecutionSuccess + ExecutionSuccess(complete = true) } catch { case ex: IOException => ExecutionRuntimeFailure(ex.getMessage) } @@ -321,7 +322,7 @@ class Rql2TruffleCompilerService(engineDefinition: (Engine, Boolean))(implicit p val w = new PolyglotBinaryWriter(outputStream) try { w.writeAndFlush(v) - ExecutionSuccess + ExecutionSuccess(complete = true) } catch { case ex: IOException => ExecutionRuntimeFailure(ex.getMessage) } diff --git a/snapi-client/src/test/scala/raw/compiler/rql2/tests/Rql2CompilerTestContext.scala b/snapi-client/src/test/scala/raw/compiler/rql2/tests/Rql2CompilerTestContext.scala index c2d3a5ecb..423df52e2 100644 --- a/snapi-client/src/test/scala/raw/compiler/rql2/tests/Rql2CompilerTestContext.scala +++ b/snapi-client/src/test/scala/raw/compiler/rql2/tests/Rql2CompilerTestContext.scala @@ -640,7 +640,7 @@ trait Rql2CompilerTestContext ) match { case ExecutionValidationFailure(errs) => Left(errs.map(err => err.toString).mkString(",")) case ExecutionRuntimeFailure(err) => Left(err) - case ExecutionSuccess => Right(Path.of(outputStream.toString)) + case ExecutionSuccess(_) => Right(Path.of(outputStream.toString)) } } finally { outputStream.close() @@ -689,7 +689,7 @@ trait Rql2CompilerTestContext compilerService.execute(query, getQueryEnvironment(maybeArgs, scopes, options), maybeDecl, outputStream) match { case ExecutionValidationFailure(errs) => Left(errs.map(err => err.toString).mkString(",")) case ExecutionRuntimeFailure(err) => Left(err) - case ExecutionSuccess => Right(path) + case ExecutionSuccess(_) => Right(path) } } finally { outputStream.close() diff --git a/sql-client/src/main/scala/raw/client/sql/SqlCompilerService.scala b/sql-client/src/main/scala/raw/client/sql/SqlCompilerService.scala index a1cfdea64..699198b7f 100644 --- a/sql-client/src/main/scala/raw/client/sql/SqlCompilerService.scala +++ b/sql-client/src/main/scala/raw/client/sql/SqlCompilerService.scala @@ -150,7 +150,8 @@ class SqlCompilerService()(implicit protected val settings: RawSettings) extends source: String, environment: ProgramEnvironment, maybeDecl: Option[String], - outputStream: OutputStream + outputStream: OutputStream, + maxRows: Option[Long] ): ExecutionResponse = { try { logger.debug(s"Executing: $source") @@ -166,7 +167,7 @@ class SqlCompilerService()(implicit protected val settings: RawSettings) extends case Right(tipe) => val arguments = environment.maybeArguments.getOrElse(Array.empty) pstmt.executeWith(arguments) match { - case Right(r) => render(environment, tipe, r, outputStream) + case Right(r) => render(environment, tipe, r, outputStream, maxRows) case Left(error) => ExecutionRuntimeFailure(error) } case Left(errors) => ExecutionRuntimeFailure(errors.mkString(", ")) @@ -192,7 +193,8 @@ class SqlCompilerService()(implicit protected val settings: RawSettings) extends environment: ProgramEnvironment, tipe: RawType, v: ResultSet, - outputStream: OutputStream + outputStream: OutputStream, + maxRows: Option[Long] ): ExecutionResponse = { environment.options .get("output-format") @@ -206,23 +208,23 @@ class SqlCompilerService()(implicit protected val settings: RawSettings) extends case _ => false //settings.config.getBoolean("raw.compiler.windows-line-ending") } val lineSeparator = if (windowsLineEnding) "\r\n" else "\n" - val csvWriter = new TypedResultSetCsvWriter(outputStream, lineSeparator) + val w = new TypedResultSetCsvWriter(outputStream, lineSeparator, maxRows) try { - csvWriter.write(v, tipe) - ExecutionSuccess + w.write(v, tipe) + ExecutionSuccess(w.complete) } catch { case ex: IOException => ExecutionRuntimeFailure(ex.getMessage) } finally { - RawUtils.withSuppressNonFatalException(csvWriter.close()) + RawUtils.withSuppressNonFatalException(w.close()) } case Some("json") => if (!TypedResultSetJsonWriter.outputWriteSupport(tipe)) { ExecutionRuntimeFailure("unsupported type") } - val w = new TypedResultSetJsonWriter(outputStream) + val w = new TypedResultSetJsonWriter(outputStream, maxRows) try { w.write(v, tipe) - ExecutionSuccess + ExecutionSuccess(w.complete) } catch { case ex: IOException => ExecutionRuntimeFailure(ex.getMessage) } finally { diff --git a/sql-client/src/main/scala/raw/client/sql/writers/TypedResultSetCsvWriter.scala b/sql-client/src/main/scala/raw/client/sql/writers/TypedResultSetCsvWriter.scala index d9d40b67a..b9d40efaf 100644 --- a/sql-client/src/main/scala/raw/client/sql/writers/TypedResultSetCsvWriter.scala +++ b/sql-client/src/main/scala/raw/client/sql/writers/TypedResultSetCsvWriter.scala @@ -33,7 +33,7 @@ object TypedResultSetCsvWriter { } -class TypedResultSetCsvWriter(os: OutputStream, lineSeparator: String) { +class TypedResultSetCsvWriter(os: OutputStream, lineSeparator: String, maxRows: Option[Long]) { final private val gen = try { @@ -57,6 +57,10 @@ class TypedResultSetCsvWriter(os: OutputStream, lineSeparator: String) { final private val timestampFormatter = DateTimeFormatter.ofPattern("yyyy-MM-dd'T'HH:mm:ss.SSS") final private val timestampFormatterNoMs = DateTimeFormatter.ofPattern("yyyy-MM-dd'T'HH:mm:ss") + private var maxRowsReached = false + + def complete: Boolean = !maxRowsReached + @throws[IOException] def write(resultSet: ResultSet, t: RawType): Unit = { val RawIterableType(RawRecordType(atts, _, _), _, _) = t @@ -69,13 +73,19 @@ class TypedResultSetCsvWriter(os: OutputStream, lineSeparator: String) { } gen.setSchema(schemaBuilder.build()) gen.enable(STRICT_CHECK_FOR_QUOTING) - while (resultSet.next()) { - gen.writeStartObject() - for (i <- 0 until distincted.size()) { - gen.writeFieldName(distincted.get(i)) - writeValue(resultSet, i + 1, atts(i).tipe) + var rowsWritten = 0L + while (resultSet.next() && !maxRowsReached) { + if (maxRows.isDefined && rowsWritten >= maxRows.get) { + maxRowsReached = true + } else { + gen.writeStartObject() + for (i <- 0 until distincted.size()) { + gen.writeFieldName(distincted.get(i)) + writeValue(resultSet, i + 1, atts(i).tipe) + } + gen.writeEndObject() + rowsWritten += 1 } - gen.writeEndObject() } } diff --git a/sql-client/src/main/scala/raw/client/sql/writers/TypedResultSetJsonWriter.scala b/sql-client/src/main/scala/raw/client/sql/writers/TypedResultSetJsonWriter.scala index 097c095d3..7b9da6744 100644 --- a/sql-client/src/main/scala/raw/client/sql/writers/TypedResultSetJsonWriter.scala +++ b/sql-client/src/main/scala/raw/client/sql/writers/TypedResultSetJsonWriter.scala @@ -33,7 +33,7 @@ object TypedResultSetJsonWriter { } -class TypedResultSetJsonWriter(os: OutputStream) { +class TypedResultSetJsonWriter(os: OutputStream, maxRows: Option[Long]) { final private val gen = try { @@ -49,6 +49,10 @@ class TypedResultSetJsonWriter(os: OutputStream) { final private val timestampFormatter = DateTimeFormatter.ofPattern("yyyy-MM-dd'T'HH:mm:ss.SSS") final private val mapper = new ObjectMapper(); + private var maxRowsReached = false + + def complete: Boolean = !maxRowsReached + @throws[IOException] def write(resultSet: ResultSet, t: RawType): Unit = { val RawIterableType(RawRecordType(atts, _, _), _, _) = t @@ -56,15 +60,21 @@ class TypedResultSetJsonWriter(os: OutputStream) { atts.foreach(a => keys.add(a.idn)) val distincted = RecordFieldsNaming.makeDistinct(keys) gen.writeStartArray() - while (resultSet.next()) { - gen.writeStartObject() - for (i <- 0 until distincted.size()) { - val field = distincted.get(i) - val t = atts(i).tipe - gen.writeFieldName(field) - writeValue(resultSet, i + 1, t) + var rowsWritten = 0L + while (resultSet.next() && !maxRowsReached) { + if (maxRows.isDefined && rowsWritten >= maxRows.get) { + maxRowsReached = true + } else { + gen.writeStartObject() + for (i <- 0 until distincted.size()) { + val field = distincted.get(i) + val t = atts(i).tipe + gen.writeFieldName(field) + writeValue(resultSet, i + 1, t) + } + gen.writeEndObject() + rowsWritten += 1 } - gen.writeEndObject() } gen.writeEndArray() } diff --git a/sql-client/src/test/scala/raw/client/sql/TestSqlCompilerServiceAirports.scala b/sql-client/src/test/scala/raw/client/sql/TestSqlCompilerServiceAirports.scala index c94b1dd61..49eb17542 100644 --- a/sql-client/src/test/scala/raw/client/sql/TestSqlCompilerServiceAirports.scala +++ b/sql-client/src/test/scala/raw/client/sql/TestSqlCompilerServiceAirports.scala @@ -371,7 +371,7 @@ class TestSqlCompilerServiceAirports asJson(), None, baos - ) == ExecutionSuccess + ) == ExecutionSuccess(true) ) assert( baos.toString() == @@ -416,7 +416,7 @@ class TestSqlCompilerServiceAirports environment, None, baos - ) == ExecutionSuccess + ) == ExecutionSuccess(true) ) assert( baos.toString() == @@ -448,7 +448,7 @@ class TestSqlCompilerServiceAirports environment, None, baos - ) == ExecutionSuccess + ) == ExecutionSuccess(true) ) assert(baos.toString() == "[]") } @@ -483,7 +483,7 @@ class TestSqlCompilerServiceAirports environment, None, baos - ) == ExecutionSuccess + ) == ExecutionSuccess(true) ) assert(baos.toString() == """[{"n":6}]""") } @@ -601,7 +601,7 @@ class TestSqlCompilerServiceAirports assert(!description.maybeRunnable.get.params.get.head.required) assert(description.maybeRunnable.get.params.get.head.defaultValue.contains(RawInt(2))) val baos = new ByteArrayOutputStream() - assert(compilerService.execute(t.q, asJson(), None, baos) == ExecutionSuccess) + assert(compilerService.execute(t.q, asJson(), None, baos) == ExecutionSuccess(true)) assert(baos.toString() == """[{"v":4}]""") } @@ -665,13 +665,13 @@ class TestSqlCompilerServiceAirports assert(param.defaultValue.isEmpty) val baos = new ByteArrayOutputStream() baos.reset() - assert(compilerService.execute(t.q, withCity, None, baos) == ExecutionSuccess) + assert(compilerService.execute(t.q, withCity, None, baos) == ExecutionSuccess(true)) assert(baos.toString() == """[{"count":1}]""") baos.reset() - assert(compilerService.execute(t.q, withNull, None, baos) == ExecutionSuccess) + assert(compilerService.execute(t.q, withNull, None, baos) == ExecutionSuccess(true)) assert(baos.toString() == """[{"count":0}]""") baos.reset() - assert(compilerService.execute(t.q, withCountry, None, baos) == ExecutionSuccess) + assert(compilerService.execute(t.q, withCountry, None, baos) == ExecutionSuccess(true)) assert(baos.toString() == """[{"count":39}]""") } @@ -699,10 +699,10 @@ class TestSqlCompilerServiceAirports assert(param.defaultValue.isEmpty) val baos = new ByteArrayOutputStream() baos.reset() - assert(compilerService.execute(t.q, withCity, None, baos) == ExecutionSuccess) + assert(compilerService.execute(t.q, withCity, None, baos) == ExecutionSuccess(true)) assert(baos.toString() == """[{"count":1}]""") baos.reset() - assert(compilerService.execute(t.q, withNull, None, baos) == ExecutionSuccess) + assert(compilerService.execute(t.q, withNull, None, baos) == ExecutionSuccess(true)) assert(baos.toString() == """[{"count":3}]""") } @@ -775,7 +775,7 @@ class TestSqlCompilerServiceAirports asJson(), None, baos - ) == ExecutionSuccess + ) == ExecutionSuccess(true) ) assert( baos.toString() == @@ -806,7 +806,7 @@ class TestSqlCompilerServiceAirports val baos = new ByteArrayOutputStream() baos.reset() val noParam = asJson() - assert(compilerService.execute(t.q, noParam, None, baos) == ExecutionSuccess) + assert(compilerService.execute(t.q, noParam, None, baos) == ExecutionSuccess(true)) assert(baos.toString() == """[{"count":8107}]""") } @@ -818,7 +818,7 @@ class TestSqlCompilerServiceAirports val baos = new ByteArrayOutputStream() baos.reset() val noParam = asJson() - assert(compilerService.execute(t.q, noParam, None, baos) == ExecutionSuccess) + assert(compilerService.execute(t.q, noParam, None, baos) == ExecutionSuccess(true)) assert(baos.toString() == """[{"count":8107}]""") } @@ -832,7 +832,7 @@ class TestSqlCompilerServiceAirports val baos = new ByteArrayOutputStream() baos.reset() val noParam = asJson() - assert(compilerService.execute(t.q, noParam, None, baos) == ExecutionSuccess) + assert(compilerService.execute(t.q, noParam, None, baos) == ExecutionSuccess(true)) assert(baos.toString() == """[{"count":8107}]""") } @@ -899,7 +899,7 @@ class TestSqlCompilerServiceAirports assert(compilerService.validate(q, env).messages.isEmpty) val GetProgramDescriptionSuccess(_) = compilerService.getProgramDescription(q, env) baos.reset() - assert(compilerService.execute(q, env, None, baos) == ExecutionSuccess) + assert(compilerService.execute(q, env, None, baos) == ExecutionSuccess(true)) baos.toString } // assert(runWith("SELECT e.airport_id FROM example.airports e", Set.empty) == """[]""")