Skip to content

Commit

Permalink
fixed tests
Browse files Browse the repository at this point in the history
  • Loading branch information
alexzerntev committed Jan 26, 2024
1 parent a8e3603 commit 7cd7dc6
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@ import raw.sources.bytestream.api.HttpLocationsTestContext

trait FunctionPackageTest extends CompilerTestContext with HttpLocationsTestContext {

test("""test() = 1 + 1
|Function.InvokeAfter(() -> test(), 2)""".stripMargin)(_ should evaluateTo("2"))
test("""Function.InvokeAfter(() -> 1 +1, 10)""".stripMargin)(_ should evaluateTo("2"))

test("invoke function after 5 seconds") { _ =>
val start = System.nanoTime()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import raw.compiler.rql2.source.*;
import raw.compiler.snapi.truffle.TruffleEmitter;
import raw.runtime.truffle.ExpressionNode;
import raw.runtime.truffle.RawContext;
import raw.runtime.truffle.RawLanguage;
import raw.runtime.truffle.StatementNode;
import raw.runtime.truffle.ast.ProgramExpressionNode;
Expand All @@ -40,7 +39,8 @@
import raw.runtime.truffle.ast.expressions.binary.MultNodeGen;
import raw.runtime.truffle.ast.expressions.binary.SubNodeGen;
import raw.runtime.truffle.ast.expressions.function.ClosureNode;
import raw.runtime.truffle.ast.expressions.function.InvokeNode;
import raw.runtime.truffle.ast.expressions.function.InvokeWithNamesNode;
import raw.runtime.truffle.ast.expressions.function.LambdaNode;
import raw.runtime.truffle.ast.expressions.function.MethodNode;
import raw.runtime.truffle.ast.expressions.literals.*;
import raw.runtime.truffle.ast.expressions.option.OptionNoneNode;
Expand Down Expand Up @@ -128,7 +128,7 @@ protected StatementNode emitMethod(Rql2Method m) {
ExpressionNode[] defaultArgs = JavaConverters.asJavaCollection(fp.ps()).stream()
.map(p -> p.e().isDefined() ? recurseExp(p.e().get()) : null)
.toArray(ExpressionNode[]::new);
MethodNode functionLiteralNode = new MethodNode(f, defaultArgs);
MethodNode functionLiteralNode = new MethodNode(m.i().idn(),f, defaultArgs);
int slot = getFrameDescriptorBuilder().addSlot(FrameSlotKind.Object, getIdnName(entity), null);
addSlot(entity, Integer.toString(slot));
return WriteLocalVariableNodeGen.create(functionLiteralNode, slot, null);
Expand Down Expand Up @@ -364,11 +364,17 @@ yield switch (entity) {
yield new ExpBlockNode(decls, recurseExp(let.e()));
}
case FunAbs fa -> {
Function f = recurseFunProto(fa.p());
ExpressionNode[] defaultArgs = JavaConverters.asJavaCollection(fa.p().ps()).stream()
.map(p -> p.e().isDefined() ? recurseExp(p.e().get()) : null)
.toArray(ExpressionNode[]::new);
yield new ClosureNode(f, defaultArgs);
if (analyzer.freeVars(fa).isEmpty() && fa.p().ps().forall(p -> p.t().isEmpty())) {
Function f = recurseFunProto(fa.p());
yield new LambdaNode(f);
}
else{
Function f = recurseFunProto(fa.p());
ExpressionNode[] defaultArgs = JavaConverters.asJavaCollection(fa.p().ps()).stream()
.map(p -> p.e().isDefined() ? recurseExp(p.e().get()) : null)
.toArray(ExpressionNode[]::new);
yield new ClosureNode(f, defaultArgs);
}
}
case FunApp fa when tipe(fa.f()) instanceof PackageEntryType -> {
Type t = tipe(fa);
Expand All @@ -388,7 +394,7 @@ case FunApp fa when tipe(fa.f()) instanceof PackageEntryType -> {
case FunApp fa -> {
String[] argNames = JavaConverters.asJavaCollection(fa.args()).stream().map(a -> a.idn().isDefined() ? a.idn().get() : null).toArray(String[]::new);
ExpressionNode[] exps = JavaConverters.asJavaCollection(fa.args()).stream().map(a -> recurseExp(a.e())).toArray(ExpressionNode[]::new);
yield new InvokeNode(recurseExp(fa.f()), argNames, exps);
yield new InvokeWithNamesNode(recurseExp(fa.f()), argNames, exps);
}
default -> throw new RawTruffleInternalErrorException("Unknown expression type");
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,36 @@
import raw.runtime.truffle.ExpressionNode;
import raw.runtime.truffle.runtime.function.FunctionExecuteNodes;
import raw.runtime.truffle.runtime.function.FunctionExecuteNodesFactory;
import raw.runtime.truffle.runtime.function.Lambda;

@NodeInfo(shortName = "invoke")
public final class InvokeNode extends ExpressionNode {
@NodeInfo(shortName = "InvokeWithNames")
public final class InvokeWithNamesNode extends ExpressionNode {

@Child private ExpressionNode functionNode;

@Child
private FunctionExecuteNodes.FunctionExecuteWithNames functionExec =
private FunctionExecuteNodes.FunctionExecuteWithNames functionExecWithNames =
FunctionExecuteNodesFactory.FunctionExecuteWithNamesNodeGen.create();

@Child
private FunctionExecuteNodes.FunctionExecuteZero functionExecZero =
FunctionExecuteNodesFactory.FunctionExecuteZeroNodeGen.create();

@Child
private FunctionExecuteNodes.FunctionExecuteOne functionExecOne =
FunctionExecuteNodesFactory.FunctionExecuteOneNodeGen.create();

@Child
private FunctionExecuteNodes.FunctionExecuteTwo functionExecTwo =
FunctionExecuteNodesFactory.FunctionExecuteTwoNodeGen.create();

@Children private final ExpressionNode[] argumentNodes;

private final Object[] argumentValues;

private final String[] argNames;

public InvokeNode(
public InvokeWithNamesNode(
ExpressionNode functionNode, String[] argNames, ExpressionNode[] argumentNodes) {
this.functionNode = functionNode;
assert (argNames.length == argumentNodes.length);
Expand All @@ -50,12 +63,24 @@ public InvokeNode(
@Override
public Object executeGeneric(VirtualFrame frame) {
CompilerAsserts.compilationConstant(argumentNodes.length);

Object function = functionNode.executeGeneric(frame);
if (function instanceof Lambda) {
if (argNames.length == 0) {
return functionExecZero.execute(this, function);
} else if (argNames.length == 1) {
return functionExecOne.execute(this, function, argumentNodes[0].executeGeneric(frame));
} else if (argNames.length == 2) {
return functionExecTwo.execute(
this,
function,
argumentNodes[0].executeGeneric(frame),
argumentNodes[1].executeGeneric(frame));
}
}
for (int i = 0; i < argumentNodes.length; i++) {
argumentValues[i] = argumentNodes[i].executeGeneric(frame);
}
return functionExec.execute(this, function, argNames, argumentValues);
return functionExecWithNames.execute(this, function, argNames, argumentValues);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,24 @@
package raw.runtime.truffle.ast.expressions.function;

public class LambdaNode {}
import com.oracle.truffle.api.CompilerDirectives;
import com.oracle.truffle.api.RootCallTarget;
import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.nodes.ExplodeLoop;
import raw.runtime.truffle.ExpressionNode;
import raw.runtime.truffle.runtime.function.Function;
import raw.runtime.truffle.runtime.function.Lambda;

public class LambdaNode extends ExpressionNode {

@CompilerDirectives.CompilationFinal private final RootCallTarget callTarget;

public LambdaNode(Function f) {
callTarget = f.getCallTarget();
}

@Override
@ExplodeLoop
public Object executeGeneric(VirtualFrame virtualFrame) {
return new Lambda(callTarget, virtualFrame);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.nodes.Node;
import raw.runtime.truffle.ExpressionNode;
import raw.runtime.truffle.RawContext;
import raw.runtime.truffle.runtime.function.Function;
import raw.runtime.truffle.runtime.function.NonClosure;

Expand All @@ -29,9 +30,12 @@ public final class MethodNode extends ExpressionNode {
@CompilerDirectives.CompilationFinal private NonClosure nonClosure;
@Node.Children private final ExpressionNode[] defaultArgumentExps;

public MethodNode(Function f, ExpressionNode[] defaultArgumentExps) {
private final String name;

public MethodNode(String name, Function f, ExpressionNode[] defaultArgumentExps) {
this.function = f;
this.defaultArgumentExps = defaultArgumentExps;
this.name = name;
}

@Override
Expand All @@ -49,6 +53,7 @@ public Object executeGeneric(VirtualFrame virtualFrame) {
}
nonClosure = new NonClosure(this.function, defaultArguments, virtualFrame);
}
RawContext.get(this).getFunctionRegistry().register(name, nonClosure);
}
return nonClosure;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ public String[] getArgNames() {
@ExportMessage
abstract static class Execute {

// Expects virtual frame to be already set
@Specialization(
limit = "INLINE_CACHE_SIZE",
guards = "nonClosure.getCallTarget() == cachedTarget")
Expand All @@ -61,15 +60,24 @@ protected static Object doDirect(
Object[] arguments,
@Cached("nonClosure.getCallTarget()") RootCallTarget cachedTarget,
@Cached("create(cachedTarget)") DirectCallNode callNode) {
setArgs(nonClosure, arguments, arguments);
return callNode.call(arguments);
Object[] finalArgs = new Object[nonClosure.getArgNames().length + 1];
finalArgs[0] = nonClosure.frame;
System.arraycopy(
nonClosure.defaultArguments, 0, finalArgs, 1, nonClosure.getArgNames().length);
setArgs(nonClosure, arguments, finalArgs);
return callNode.call(finalArgs);
}

@Specialization(replaces = "doDirect")
protected static Object doIndirect(
NonClosure nonClosure, Object[] arguments, @Cached IndirectCallNode callNode) {
setArgs(nonClosure, arguments, arguments);
return callNode.call(nonClosure.getCallTarget(), arguments);
Object[] finalArgs = new Object[nonClosure.getArgNames().length + 1];
finalArgs[0] = nonClosure.frame;
System.arraycopy(
nonClosure.defaultArguments, 0, finalArgs, 1, nonClosure.getArgNames().length);
setArgs(nonClosure, arguments, finalArgs);

return callNode.call(nonClosure.getCallTarget(), finalArgs);
}

private static void setArgs(NonClosure nonClosure, Object[] arguments, Object[] finalArgs) {
Expand Down

0 comments on commit 7cd7dc6

Please sign in to comment.