Skip to content

Commit

Permalink
Simplify RawValue (#470)
Browse files Browse the repository at this point in the history
* Remove "snapi-only" types from the raw interface (e.g.
RawLocationType).
* Instead, use directly Rql2Value for the staged compiler.

---------

Co-authored-by: Benjamin Gaidioz <ben@raw-labs.com>
  • Loading branch information
miguelbranco80 and bgaidioz authored Jul 30, 2024
1 parent 7e177b2 commit 87bde80
Show file tree
Hide file tree
Showing 20 changed files with 298 additions and 383 deletions.
134 changes: 1 addition & 133 deletions client/src/main/scala/raw/client/api/CompilerService.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,11 @@
package raw.client.api

import com.fasterxml.jackson.annotation.{JsonSubTypes, JsonTypeInfo}
import org.graalvm.polyglot.{Engine, Value}
import org.graalvm.polyglot.Engine
import raw.utils.{RawException, RawService, RawSettings}

import java.io.OutputStream
import scala.collection.mutable
import scala.util.control.NonFatal
import com.fasterxml.jackson.annotation.JsonSubTypes.{Type => JsonType}

// Exception that wraps the underlying error so that it includes the extra debug info.
Expand Down Expand Up @@ -89,137 +88,6 @@ object CompilerService {
)
}

def polyglotValueToRawValue(v: Value, t: RawType): RawValue = {
if (t.triable) {
if (v.isException) {
try {
v.throwException()
throw new AssertionError("should not happen")
} catch {
case NonFatal(ex) => RawError(ex.getMessage)
}
} else {
// Success, recurse without the tryable property.
polyglotValueToRawValue(v, t.cloneNotTriable)
}
} else if (t.nullable) {
if (v.isNull) {
RawNull()
} else {
polyglotValueToRawValue(v, t.cloneNotNullable)
}
} else {
t match {
case _: RawUndefinedType => throw new AssertionError("RawUndefined is not triable and is not nullable.")
case _: RawAnyType => RawAny(v)
case _: RawBoolType => RawBool(v.asBoolean())
case _: RawStringType => RawString(v.asString())
case _: RawByteType => RawByte(v.asByte())
case _: RawShortType => RawShort(v.asShort())
case _: RawIntType => RawInt(v.asInt())
case _: RawLongType => RawLong(v.asLong())
case _: RawFloatType => RawFloat(v.asFloat())
case _: RawDoubleType => RawDouble(v.asDouble())
case _: RawDecimalType =>
val bg = BigDecimal(v.asString())
RawDecimal(bg.bigDecimal)
case _: RawDateType =>
val date = v.asDate()
RawDate(date)
case _: RawTimeType =>
val time = v.asTime()
RawTime(time)
case _: RawTimestampType =>
val localDate = v.asDate()
val localTime = v.asTime()
RawTimestamp(localDate.atTime(localTime))
case _: RawIntervalType =>
val d = v.asDuration()
RawInterval(0, 0, 0, d.toDaysPart.toInt, d.toHoursPart, d.toMinutesPart, d.toSecondsPart, d.toMillisPart)
case _: RawBinaryType =>
val bufferSize = v.getBufferSize.toInt
val byteArray = new Array[Byte](bufferSize)
for (i <- 0 until bufferSize) {
byteArray(i) = v.readBufferByte(i)
}
RawBinary(byteArray)
case RawRecordType(atts, _, _) =>
val vs = atts.map(att => polyglotValueToRawValue(v.getMember(att.idn), att.tipe))
RawRecord(vs)
case RawListType(innerType, _, _) =>
val seq = mutable.ArrayBuffer[RawValue]()
for (i <- 0L until v.getArraySize) {
val v1 = v.getArrayElement(i)
seq.append(polyglotValueToRawValue(v1, innerType))
}
RawList(seq)
case RawIterableType(innerType, _, _) =>
val seq = mutable.ArrayBuffer[RawValue]()
val it = v.getIterator
while (it.hasIteratorNextElement) {
val v1 = it.getIteratorNextElement
seq.append(polyglotValueToRawValue(v1, innerType))
}
if (it.canInvokeMember("close")) {
val callable = it.getMember("close")
callable.execute()
}
RawIterable(seq)
case RawOrType(tipes, _, _) =>
val idx = v.getMember("index").asInt()
val v1 = v.getMember("value")
val tipe = tipes(idx)
polyglotValueToRawValue(v1, tipe)
case _: RawLocationType =>
val url = v.asString
assert(v.hasMembers);
val members = v.getMemberKeys
val settings = mutable.Map.empty[LocationSettingKey, LocationSettingValue]
val keys = members.iterator()
while (keys.hasNext) {
val key = keys.next()
val tv = v.getMember(key)
val value =
if (tv.isNumber) LocationIntSetting(tv.asInt)
else if (tv.isBoolean) LocationBooleanSetting(tv.asBoolean)
else if (tv.isString) LocationStringSetting(tv.asString)
else if (tv.hasBufferElements) {
val bufferSize = tv.getBufferSize.toInt
val byteArray = new Array[Byte](bufferSize)
for (i <- 0 until bufferSize) {
byteArray(i) = tv.readBufferByte(i)
}
LocationBinarySetting(byteArray)
} else if (tv.isDuration) LocationDurationSetting(tv.asDuration())
else if (tv.hasArrayElements) {
// in the context of a location, it's int-array for sure
val size = tv.getArraySize
val array = new Array[Int](size.toInt)
for (i <- 0L until size) {
array(i.toInt) = tv.getArrayElement(i).asInt
}
LocationIntArraySetting(array)
} else if (tv.hasHashEntries) {
// kv settings
val iterator = tv.getHashEntriesIterator
val keyValues = mutable.ArrayBuffer.empty[(String, String)]
while (iterator.hasIteratorNextElement) {
val kv = iterator.getIteratorNextElement // array with two elements: key and value
val key = kv.getArrayElement(0).asString
val value = kv.getArrayElement(1).asString
keyValues += ((key, value))
}
LocationKVSetting(keyValues)
} else {
throw new AssertionError("Unexpected value type: " + tv)
}
settings.put(LocationSettingKey(key), value)
}
RawLocation(LocationDescription(url, settings.toMap))
}
}
}

}

trait CompilerService extends RawService {
Expand Down
14 changes: 1 addition & 13 deletions client/src/main/scala/raw/client/api/RawValues.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import com.fasterxml.jackson.annotation.{JsonSubTypes, JsonTypeInfo}
@JsonSubTypes(
Array(
new JsonType(value = classOf[RawNull], name = "null"),
new JsonType(value = classOf[RawError], name = "error"),
new JsonType(value = classOf[RawByte], name = "byte"),
new JsonType(value = classOf[RawShort], name = "short"),
new JsonType(value = classOf[RawInt], name = "int"),
Expand All @@ -30,21 +29,15 @@ import com.fasterxml.jackson.annotation.{JsonSubTypes, JsonTypeInfo}
new JsonType(value = classOf[RawBool], name = "bool"),
new JsonType(value = classOf[RawString], name = "string"),
new JsonType(value = classOf[RawBinary], name = "binary"),
new JsonType(value = classOf[RawLocation], name = "location"),
new JsonType(value = classOf[RawDate], name = "date"),
new JsonType(value = classOf[RawTime], name = "time"),
new JsonType(value = classOf[RawTimestamp], name = "timestamp"),
new JsonType(value = classOf[RawInterval], name = "interval"),
new JsonType(value = classOf[RawRecord], name = "record"),
new JsonType(value = classOf[RawList], name = "list"),
new JsonType(value = classOf[RawIterable], name = "iterable"),
new JsonType(value = classOf[RawOr], name = "or")
new JsonType(value = classOf[RawInterval], name = "interval")
)
)
sealed trait RawValue
final case class RawAny(v: Any) extends RawValue
final case class RawNull() extends RawValue
final case class RawError(v: String) extends RawValue
final case class RawByte(v: java.lang.Byte) extends RawValue
final case class RawShort(v: java.lang.Short) extends RawValue
final case class RawInt(v: java.lang.Integer) extends RawValue
Expand All @@ -55,7 +48,6 @@ final case class RawDecimal(v: java.math.BigDecimal) extends RawValue
final case class RawBool(v: java.lang.Boolean) extends RawValue
final case class RawString(v: java.lang.String) extends RawValue
final case class RawBinary(v: Array[Byte]) extends RawValue
final case class RawLocation(v: LocationDescription) extends RawValue
final case class RawDate(v: java.time.LocalDate) extends RawValue
final case class RawTime(v: java.time.LocalTime) extends RawValue
final case class RawTimestamp(v: java.time.LocalDateTime) extends RawValue
Expand All @@ -69,7 +61,3 @@ final case class RawInterval(
seconds: Int,
millis: Int
) extends RawValue
final case class RawRecord(v: Seq[RawValue]) extends RawValue
final case class RawList(v: Seq[RawValue]) extends RawValue
final case class RawIterable(v: Seq[RawValue]) extends RawValue // Data has been ready is now materialized.
final case class RawOr(vs: Seq[RawValue]) extends RawValue
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
package raw.compiler.rql2

import raw.compiler.base.errors.ErrorCompilerMessage
import raw.compiler.rql2.api.{PackageExtension, PackageExtensionProvider, Value}
import raw.compiler.rql2.api.{PackageExtension, PackageExtensionProvider, Rql2Value}
import raw.compiler.rql2.source.Rql2Program
import raw.inferrer.api.{InferrerProperties, InputFormatDescriptor}

Expand All @@ -25,7 +25,7 @@ trait ProgramContext extends raw.compiler.base.ProgramContext {

private val dynamicPackageCache = new mutable.HashMap[String, PackageExtension]

private val stageCompilerCache = new mutable.HashMap[Rql2Program, Either[ErrorCompilerMessage, Value]]
private val stageCompilerCache = new mutable.HashMap[Rql2Program, Either[ErrorCompilerMessage, Rql2Value]]

def infer(
inferrerProperties: InferrerProperties
Expand All @@ -46,8 +46,8 @@ trait ProgramContext extends raw.compiler.base.ProgramContext {

def getOrAddStagedCompilation(
program: Rql2Program,
f: => Either[ErrorCompilerMessage, Value]
): Either[ErrorCompilerMessage, Value] = {
f: => Either[ErrorCompilerMessage, Rql2Value]
): Either[ErrorCompilerMessage, Rql2Value] = {
stageCompilerCache.getOrElseUpdate(program, f)
}

Expand Down
74 changes: 39 additions & 35 deletions snapi-frontend/src/main/scala/raw/compiler/rql2/Propagation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,22 +19,22 @@ import raw.compiler.base.source.Type
import raw.compiler.common.source._
import raw.compiler.rql2.api.{
Arg,
BoolValue,
ByteValue,
DateValue,
DoubleValue,
FloatValue,
IntValue,
IntervalValue,
ListValue,
LongValue,
OptionValue,
RecordValue,
ShortValue,
StringValue,
TimeValue,
TimestampValue,
Value,
Rql2BoolValue,
Rql2ByteValue,
Rql2DateValue,
Rql2DoubleValue,
Rql2FloatValue,
Rql2IntValue,
Rql2IntervalValue,
Rql2ListValue,
Rql2LongValue,
Rql2OptionValue,
Rql2RecordValue,
Rql2ShortValue,
Rql2StringValue,
Rql2TimeValue,
Rql2TimestampValue,
Rql2Value,
ValueArg
}
import raw.compiler.rql2.builtin._
Expand All @@ -56,18 +56,22 @@ class Propagation(protected val parent: Phase[SourceProgram], protected val phas
val tree = new Tree(program)
lazy val analyzer = tree.analyzer

case class TypeAndValue(t: Type, value: Option[Value])
case class TypeAndValue(t: Type, value: Option[Rql2Value])
case class ExpProps(
ne: Exp,
t: Type,
castNeeded: Boolean,
props: Set[Rql2TypeProperty],
value: Option[Value] = None
value: Option[Rql2Value] = None
)

// Returns the type of argument at index `idx`. If the index points to a ValueArg,
// add the computed value
def getTypeAndValue(entryArguments: FunAppPackageEntryArguments, args: Seq[FunAppArg], idx: Int): Option[Value] = {
def getTypeAndValue(
entryArguments: FunAppPackageEntryArguments,
args: Seq[FunAppArg],
idx: Int
): Option[Rql2Value] = {
val arg: Arg = args(idx).idn match {
case Some(i) => entryArguments.optionalArgs.collectFirst { case a if a._1 == i => a._2 }.get
case None =>
Expand Down Expand Up @@ -342,32 +346,32 @@ class Propagation(protected val parent: Phase[SourceProgram], protected val phas
r
}

private def valueToExp(value: Value, t: Type): Exp = value match {
case ByteValue(v) => ByteConst(v.toString)
case ShortValue(v) => ShortConst(v.toString)
case IntValue(v) => IntConst(v.toString)
case LongValue(v) => LongConst(v.toString)
case FloatValue(v) => FloatConst(v.toString)
case DoubleValue(v) => DoubleConst(v.toString)
case StringValue(v) => StringConst(v)
case BoolValue(v) => BoolConst(v)
case OptionValue(option) =>
private def valueToExp(value: Rql2Value, t: Type): Exp = value match {
case Rql2ByteValue(v) => ByteConst(v.toString)
case Rql2ShortValue(v) => ShortConst(v.toString)
case Rql2IntValue(v) => IntConst(v.toString)
case Rql2LongValue(v) => LongConst(v.toString)
case Rql2FloatValue(v) => FloatConst(v.toString)
case Rql2DoubleValue(v) => DoubleConst(v.toString)
case Rql2StringValue(v) => StringConst(v)
case Rql2BoolValue(v) => BoolConst(v)
case Rql2OptionValue(option) =>
val innerType = resetProps(t, Set.empty)
option
.map(v => valueToExp(v, innerType))
.map(NullablePackageBuilder.Build(_))
.getOrElse(NullablePackageBuilder.Empty(innerType))
case RecordValue(r) =>
case Rql2RecordValue(r) =>
val Rql2RecordType(atts, _) = t
val fields = r.zip(atts).map { case (v, att) => att.idn -> valueToExp(v, att.tipe) }
RecordPackageBuilder.Build(fields.toVector)
case ListValue(v) =>
case Rql2ListValue(v) =>
val Rql2ListType(innerType, _) = t
ListPackageBuilder.Build(v.map(x => valueToExp(x, innerType)): _*)
case DateValue(v) => DatePackageBuilder.FromLocalDate(v)
case TimeValue(v) => TimePackageBuilder.FromLocalTime(v)
case TimestampValue(v) => TimestampPackageBuilder.FromLocalDateTime(v)
case IntervalValue(
case Rql2DateValue(v) => DatePackageBuilder.FromLocalDate(v)
case Rql2TimeValue(v) => TimePackageBuilder.FromLocalTime(v)
case Rql2TimestampValue(v) => TimestampPackageBuilder.FromLocalDateTime(v)
case Rql2IntervalValue(
years,
month,
weeks,
Expand Down
Loading

0 comments on commit 87bde80

Please sign in to comment.