Implement suspension of default arguments (#170)

This commit is contained in:
Ara Adkins 2019-09-05 18:26:21 +01:00 committed by GitHub
parent f8dea12e44
commit ed8223c57c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 140 additions and 45 deletions

View File

@ -12,6 +12,7 @@ import org.enso.interpreter.runtime.scope.LocalScope;
* runtime nodes used by the interpreter to guide function evaluation. * runtime nodes used by the interpreter to guide function evaluation.
*/ */
public class CallArgFactory implements AstCallArgVisitor<CallArgument> { public class CallArgFactory implements AstCallArgVisitor<CallArgument> {
private final LocalScope scope; private final LocalScope scope;
private final Language language; private final Language language;
private final String scopeName; private final String scopeName;

View File

@ -4,19 +4,42 @@ import com.oracle.truffle.api.RootCallTarget;
import com.oracle.truffle.api.Truffle; import com.oracle.truffle.api.Truffle;
import com.oracle.truffle.api.frame.FrameSlot; import com.oracle.truffle.api.frame.FrameSlot;
import com.oracle.truffle.api.nodes.RootNode; import com.oracle.truffle.api.nodes.RootNode;
import org.enso.interpreter.*; import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.enso.interpreter.AstArgDefinition;
import org.enso.interpreter.AstCallArg;
import org.enso.interpreter.AstCase;
import org.enso.interpreter.AstCaseFunction;
import org.enso.interpreter.AstExpression;
import org.enso.interpreter.AstExpressionVisitor;
import org.enso.interpreter.Language;
import org.enso.interpreter.node.EnsoRootNode; import org.enso.interpreter.node.EnsoRootNode;
import org.enso.interpreter.node.ExpressionNode; import org.enso.interpreter.node.ExpressionNode;
import org.enso.interpreter.node.callable.InvokeCallableNodeGen; import org.enso.interpreter.node.callable.InvokeCallableNodeGen;
import org.enso.interpreter.node.callable.argument.ReadArgumentNode; import org.enso.interpreter.node.callable.argument.ReadArgumentNode;
import org.enso.interpreter.node.callable.function.CreateFunctionNode; import org.enso.interpreter.node.callable.function.CreateFunctionNode;
import org.enso.interpreter.node.callable.function.FunctionBodyNode; import org.enso.interpreter.node.callable.function.FunctionBodyNode;
import org.enso.interpreter.node.controlflow.*; import org.enso.interpreter.node.controlflow.CaseNode;
import org.enso.interpreter.node.controlflow.ConstructorCaseNode;
import org.enso.interpreter.node.controlflow.DefaultFallbackNode;
import org.enso.interpreter.node.controlflow.FallbackNode;
import org.enso.interpreter.node.controlflow.IfZeroNode;
import org.enso.interpreter.node.controlflow.MatchNode;
import org.enso.interpreter.node.expression.builtin.PrintNode; import org.enso.interpreter.node.expression.builtin.PrintNode;
import org.enso.interpreter.node.expression.constant.ConstructorNode; import org.enso.interpreter.node.expression.constant.ConstructorNode;
import org.enso.interpreter.node.expression.constant.DynamicSymbolNode; import org.enso.interpreter.node.expression.constant.DynamicSymbolNode;
import org.enso.interpreter.node.expression.literal.IntegerLiteralNode; import org.enso.interpreter.node.expression.literal.IntegerLiteralNode;
import org.enso.interpreter.node.expression.operator.*; import org.enso.interpreter.node.expression.operator.AddOperatorNodeGen;
import org.enso.interpreter.node.expression.operator.DivideOperatorNodeGen;
import org.enso.interpreter.node.expression.operator.ModOperatorNodeGen;
import org.enso.interpreter.node.expression.operator.MultiplyOperatorNodeGen;
import org.enso.interpreter.node.expression.operator.SubtractOperatorNodeGen;
import org.enso.interpreter.node.scope.AssignmentNode; import org.enso.interpreter.node.scope.AssignmentNode;
import org.enso.interpreter.node.scope.AssignmentNodeGen; import org.enso.interpreter.node.scope.AssignmentNodeGen;
import org.enso.interpreter.node.scope.ReadLocalTargetNodeGen; import org.enso.interpreter.node.scope.ReadLocalTargetNodeGen;
@ -27,16 +50,12 @@ import org.enso.interpreter.runtime.error.DuplicateArgumentNameException;
import org.enso.interpreter.runtime.scope.ModuleScope; import org.enso.interpreter.runtime.scope.ModuleScope;
import org.enso.interpreter.runtime.scope.LocalScope; import org.enso.interpreter.runtime.scope.LocalScope;
import java.util.*;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/** /**
* An {@code ExpressionFactory} is responsible for converting the majority of Enso's parsed AST into * An {@code ExpressionFactory} is responsible for converting the majority of Enso's parsed AST into
* nodes evaluated by the interpreter at runtime. * nodes evaluated by the interpreter at runtime.
*/ */
public class ExpressionFactory implements AstExpressionVisitor<ExpressionNode> { public class ExpressionFactory implements AstExpressionVisitor<ExpressionNode> {
private final LocalScope scope; private final LocalScope scope;
private final Language language; private final Language language;
private final String scopeName; private final String scopeName;
@ -128,11 +147,21 @@ public class ExpressionFactory implements AstExpressionVisitor<ExpressionNode> {
String operator, AstExpression leftAst, AstExpression rightAst) { String operator, AstExpression leftAst, AstExpression rightAst) {
ExpressionNode left = leftAst.visit(this); ExpressionNode left = leftAst.visit(this);
ExpressionNode right = rightAst.visit(this); ExpressionNode right = rightAst.visit(this);
if (operator.equals("+")) return AddOperatorNodeGen.create(left, right); if (operator.equals("+")) {
if (operator.equals("-")) return SubtractOperatorNodeGen.create(left, right); return AddOperatorNodeGen.create(left, right);
if (operator.equals("*")) return MultiplyOperatorNodeGen.create(left, right); }
if (operator.equals("/")) return DivideOperatorNodeGen.create(left, right); if (operator.equals("-")) {
if (operator.equals("%")) return ModOperatorNodeGen.create(left, right); return SubtractOperatorNodeGen.create(left, right);
}
if (operator.equals("*")) {
return MultiplyOperatorNodeGen.create(left, right);
}
if (operator.equals("/")) {
return DivideOperatorNodeGen.create(left, right);
}
if (operator.equals("%")) {
return ModOperatorNodeGen.create(left, right);
}
return null; return null;
} }
@ -300,7 +329,7 @@ public class ExpressionFactory implements AstExpressionVisitor<ExpressionNode> {
*/ */
@Override @Override
public ExpressionNode visitFunctionApplication( public ExpressionNode visitFunctionApplication(
AstExpression function, List<AstCallArg> arguments) { AstExpression function, List<AstCallArg> arguments, boolean hasDefaultsSuspended) {
CallArgFactory argFactory = new CallArgFactory(scope, language, scopeName, moduleScope); CallArgFactory argFactory = new CallArgFactory(scope, language, scopeName, moduleScope);
List<CallArgument> callArgs = new ArrayList<>(); List<CallArgument> callArgs = new ArrayList<>();
@ -310,7 +339,7 @@ public class ExpressionFactory implements AstExpressionVisitor<ExpressionNode> {
} }
return InvokeCallableNodeGen.create( return InvokeCallableNodeGen.create(
callArgs.stream().toArray(CallArgument[]::new), function.visit(this)); callArgs.toArray(new CallArgument[0]), hasDefaultsSuspended, function.visit(this));
} }
/** /**

View File

@ -8,6 +8,7 @@ import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.nodes.ExplodeLoop; import com.oracle.truffle.api.nodes.ExplodeLoop;
import com.oracle.truffle.api.nodes.NodeInfo; import com.oracle.truffle.api.nodes.NodeInfo;
import com.oracle.truffle.api.profiles.ConditionProfile; import com.oracle.truffle.api.profiles.ConditionProfile;
import java.util.Arrays;
import org.enso.interpreter.Constants; import org.enso.interpreter.Constants;
import org.enso.interpreter.node.ExpressionNode; import org.enso.interpreter.node.ExpressionNode;
import org.enso.interpreter.node.callable.argument.sorter.ArgumentSorterNode; import org.enso.interpreter.node.callable.argument.sorter.ArgumentSorterNode;
@ -22,8 +23,6 @@ import org.enso.interpreter.runtime.error.MethodDoesNotExistException;
import org.enso.interpreter.runtime.error.NotInvokableException; import org.enso.interpreter.runtime.error.NotInvokableException;
import org.enso.interpreter.runtime.type.TypesGen; import org.enso.interpreter.runtime.type.TypesGen;
import java.util.Arrays;
/** /**
* This node is responsible for organising callable calls so that they are ready to be made. * This node is responsible for organising callable calls so that they are ready to be made.
* *
@ -33,15 +32,17 @@ import java.util.Arrays;
@NodeInfo(shortName = "@", description = "Executes function") @NodeInfo(shortName = "@", description = "Executes function")
@NodeChild(value = "callable", type = ExpressionNode.class) @NodeChild(value = "callable", type = ExpressionNode.class)
public abstract class InvokeCallableNode extends ExpressionNode { public abstract class InvokeCallableNode extends ExpressionNode {
@Children
private @CompilationFinal(dimensions = 1) ExpressionNode[] argExpressions;
private final boolean canApplyThis; @Children
private final int thisArgumentPosition; @CompilationFinal(dimensions = 1)
private ExpressionNode[] argExpressions;
@Child private ArgumentSorterNode argumentSorter; @Child private ArgumentSorterNode argumentSorter;
@Child private MethodResolverNode methodResolverNode; @Child private MethodResolverNode methodResolverNode;
private final boolean canApplyThis;
private final int thisArgumentPosition;
private final ConditionProfile methodCalledOnNonAtom = ConditionProfile.createCountingProfile(); private final ConditionProfile methodCalledOnNonAtom = ConditionProfile.createCountingProfile();
/** /**
@ -49,7 +50,7 @@ public abstract class InvokeCallableNode extends ExpressionNode {
* *
* @param callArguments information on the arguments being passed to the {@link Function} * @param callArguments information on the arguments being passed to the {@link Function}
*/ */
public InvokeCallableNode(CallArgument[] callArguments) { public InvokeCallableNode(CallArgument[] callArguments, boolean hasDefaultsSuspended) {
this.argExpressions = this.argExpressions =
Arrays.stream(callArguments) Arrays.stream(callArguments)
.map(CallArgument::getExpression) .map(CallArgument::getExpression)
@ -57,9 +58,9 @@ public abstract class InvokeCallableNode extends ExpressionNode {
CallArgumentInfo[] argSchema = CallArgumentInfo[] argSchema =
Arrays.stream(callArguments).map(CallArgumentInfo::new).toArray(CallArgumentInfo[]::new); Arrays.stream(callArguments).map(CallArgumentInfo::new).toArray(CallArgumentInfo[]::new);
boolean appliesThis = false; boolean appliesThis = false;
int idx = 0; int idx = 0;
for (; idx < argSchema.length; idx++) { for (; idx < argSchema.length; idx++) {
CallArgumentInfo arg = argSchema[idx]; CallArgumentInfo arg = argSchema[idx];
if (arg.isPositional() if (arg.isPositional()
@ -68,15 +69,17 @@ public abstract class InvokeCallableNode extends ExpressionNode {
break; break;
} }
} }
this.canApplyThis = appliesThis; this.canApplyThis = appliesThis;
this.thisArgumentPosition = idx; this.thisArgumentPosition = idx;
this.argumentSorter = ArgumentSorterNodeGen.create(argSchema); this.argumentSorter = ArgumentSorterNodeGen.create(argSchema, hasDefaultsSuspended);
this.methodResolverNode = MethodResolverNodeGen.create(); this.methodResolverNode = MethodResolverNodeGen.create();
} }
/** /**
* Marks whether the {@code argumentSorter} child is tailrecursive. * Marks whether the {@code argumentSorter} child is tailrecursive.
*
* @param isTail whether or not the node is tail-recursive. * @param isTail whether or not the node is tail-recursive.
*/ */
@Override @Override

View File

@ -16,18 +16,20 @@ import org.enso.interpreter.runtime.callable.function.Function;
*/ */
@NodeInfo(shortName = "ArgumentSorter") @NodeInfo(shortName = "ArgumentSorter")
public abstract class ArgumentSorterNode extends BaseNode { public abstract class ArgumentSorterNode extends BaseNode {
private @CompilationFinal(dimensions = 1) CallArgumentInfo[] schema; private @CompilationFinal(dimensions = 1) CallArgumentInfo[] schema;
private final boolean hasDefaultsSuspended;
/** /**
* Creates a node that performs the argument organisation for the provided schema. * Creates a node that performs the argument organisation for the provided schema.
* *
* @param schema information about the call arguments in positional order * @param schema information about the call arguments in positional order
*/ */
public ArgumentSorterNode(CallArgumentInfo[] schema) { public ArgumentSorterNode(CallArgumentInfo[] schema, boolean hasDefaultsSuspended) {
this.schema = schema; this.schema = schema;
this.hasDefaultsSuspended = hasDefaultsSuspended;
} }
/** /**
* Generates the argument mapping where it has already been computed and executes the function. * Generates the argument mapping where it has already been computed and executes the function.
* *
@ -49,7 +51,8 @@ public abstract class ArgumentSorterNode extends BaseNode {
public Object invokeCached( public Object invokeCached(
Function function, Function function,
Object[] arguments, Object[] arguments,
@Cached("create(function, getSchema())") CachedArgumentSorterNode mappingNode, @Cached("create(function, getSchema(), hasDefaultsSuspended())")
CachedArgumentSorterNode mappingNode,
@Cached CallOptimiserNode optimiser) { @Cached CallOptimiserNode optimiser) {
Object[] mappedArguments = mappingNode.execute(function, arguments); Object[] mappedArguments = mappingNode.execute(function, arguments);
if (mappingNode.appliesFully()) { if (mappingNode.appliesFully()) {
@ -84,4 +87,13 @@ public abstract class ArgumentSorterNode extends BaseNode {
CallArgumentInfo[] getSchema() { CallArgumentInfo[] getSchema() {
return schema; return schema;
} }
/**
* Checks whether the function whose arguments are being sorted has suspended defaults arguments.
*
* @return {@code true} if it has suspended defaults, otherwise {@code false}
*/
boolean hasDefaultsSuspended() {
return this.hasDefaultsSuspended;
}
} }

View File

@ -13,6 +13,7 @@ import org.enso.interpreter.runtime.callable.function.Function;
*/ */
@NodeInfo(shortName = "CachedArgumentSorter") @NodeInfo(shortName = "CachedArgumentSorter")
public class CachedArgumentSorterNode extends BaseNode { public class CachedArgumentSorterNode extends BaseNode {
private final Function originalFunction; private final Function originalFunction;
private final @CompilationFinal(dimensions = 1) int[] mapping; private final @CompilationFinal(dimensions = 1) int[] mapping;
private final ArgumentSchema postApplicationSchema; private final ArgumentSchema postApplicationSchema;
@ -22,9 +23,12 @@ public class CachedArgumentSorterNode extends BaseNode {
* Creates a node that generates and then caches the argument mapping. * Creates a node that generates and then caches the argument mapping.
* *
* @param function the function to sort arguments for * @param function the function to sort arguments for
* @param schema information on the calling arguments * @param schema information on the calling argument
* @param hasDefaultsSuspended whether or not the function to which these arguments are applied
* has its defaults suspended.
*/ */
public CachedArgumentSorterNode(Function function, CallArgumentInfo[] schema) { public CachedArgumentSorterNode(
Function function, CallArgumentInfo[] schema, boolean hasDefaultsSuspended) {
this.originalFunction = function; this.originalFunction = function;
CallArgumentInfo.ArgumentMapping mapping = CallArgumentInfo.ArgumentMapping mapping =
CallArgumentInfo.ArgumentMapping.generate(function.getSchema(), schema); CallArgumentInfo.ArgumentMapping.generate(function.getSchema(), schema);
@ -33,7 +37,10 @@ public class CachedArgumentSorterNode extends BaseNode {
boolean fullApplication = true; boolean fullApplication = true;
for (int i = 0; i < postApplicationSchema.getArgumentsCount(); i++) { for (int i = 0; i < postApplicationSchema.getArgumentsCount(); i++) {
if (!(postApplicationSchema.hasDefaultAt(i) || postApplicationSchema.hasPreAppliedAt(i))) { boolean hasValidDefault = postApplicationSchema.hasDefaultAt(i) && !hasDefaultsSuspended;
boolean hasPreappliedArg = postApplicationSchema.hasPreAppliedAt(i);
if (!(hasValidDefault || hasPreappliedArg)) {
fullApplication = false; fullApplication = false;
break; break;
} }
@ -48,8 +55,9 @@ public class CachedArgumentSorterNode extends BaseNode {
* @param schema information on the calling arguments * @param schema information on the calling arguments
* @return a sorter node for the arguments in {@code schema} being passed to {@code callable} * @return a sorter node for the arguments in {@code schema} being passed to {@code callable}
*/ */
public static CachedArgumentSorterNode create(Function function, CallArgumentInfo[] schema) { public static CachedArgumentSorterNode create(
return new CachedArgumentSorterNode(function, schema); Function function, CallArgumentInfo[] schema, boolean hasDefaultsSuspended) {
return new CachedArgumentSorterNode(function, schema, hasDefaultsSuspended);
} }
/** /**

View File

@ -7,6 +7,12 @@ import com.oracle.truffle.api.TruffleLanguage;
import com.oracle.truffle.api.TruffleLanguage.Env; import com.oracle.truffle.api.TruffleLanguage.Env;
import com.oracle.truffle.api.frame.FrameDescriptor; import com.oracle.truffle.api.frame.FrameDescriptor;
import com.oracle.truffle.api.source.Source; import com.oracle.truffle.api.source.Source;
import java.io.File;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import org.enso.interpreter.AstGlobalScope; import org.enso.interpreter.AstGlobalScope;
import org.enso.interpreter.Constants; import org.enso.interpreter.Constants;
import org.enso.interpreter.EnsoParser; import org.enso.interpreter.EnsoParser;
@ -20,12 +26,6 @@ import org.enso.interpreter.util.ScalaConversions;
import org.enso.pkg.Package; import org.enso.pkg.Package;
import org.enso.pkg.SourceFile; import org.enso.pkg.SourceFile;
import java.io.*;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
/** /**
* The language context is the internal state of the language that is associated with each thread in * The language context is the internal state of the language that is associated with each thread in
* a running Enso program. * a running Enso program.

View File

@ -28,7 +28,8 @@ trait AstExpressionVisitor[+T] {
def visitFunctionApplication( def visitFunctionApplication(
function: AstExpression, function: AstExpression,
arguments: java.util.List[AstCallArg] arguments: java.util.List[AstCallArg],
defaultsSuspended: Boolean
): T ): T
def visitIf( def visitIf(
@ -156,10 +157,13 @@ case class AstVariable(name: String) extends AstExpression {
visitor.visitVariable(name) visitor.visitVariable(name)
} }
case class AstApply(fun: AstExpression, args: List[AstCallArg]) case class AstApply(
fun: AstExpression,
args: List[AstCallArg],
hasDefaultsSuspended: Boolean)
extends AstExpression { extends AstExpression {
override def visit[T](visitor: AstExpressionVisitor[T]): T = override def visit[T](visitor: AstExpressionVisitor[T]): T =
visitor.visitFunctionApplication(fun, args.asJava) visitor.visitFunctionApplication(fun, args.asJava, hasDefaultsSuspended)
} }
case class AstFunction( case class AstFunction(
@ -269,7 +273,7 @@ class EnsoParserInternal extends JavaTokenParsers {
def variable: Parser[AstVariable] = ident ^^ AstVariable def variable: Parser[AstVariable] = ident ^^ AstVariable
def operand: Parser[AstExpression] = def operand: Parser[AstExpression] =
long | foreign | variable | "(" ~> expression <~ ")" | call long | foreign | variable | "(" ~> expression <~ ")" | functionCall
def arith: Parser[AstExpression] = def arith: Parser[AstExpression] =
operand ~ ((("+" | "-" | "*" | "/" | "%") ~ operand) ?) ^^ { operand ~ ((("+" | "-" | "*" | "/" | "%") ~ operand) ?) ^^ {
@ -280,9 +284,17 @@ class EnsoParserInternal extends JavaTokenParsers {
def expression: Parser[AstExpression] = def expression: Parser[AstExpression] =
ifZero | matchClause | arith | function ifZero | matchClause | arith | function
def call: Parser[AstApply] = "@" ~> expression ~ (argList ?) ^^ { def functionCall: Parser[AstApply] =
case expr ~ args => AstApply(expr, args.getOrElse(Nil)) "@" ~> expression ~ (argList ?) ~ defaultSuspend ^^ {
} case expr ~ args ~ hasDefaultsSuspended =>
AstApply(expr, args.getOrElse(Nil), hasDefaultsSuspended)
}
def defaultSuspend: Parser[Boolean] =
("..." ?) ^^ ({
case Some(_) => true
case None => false
})
def assignment: Parser[AstAssignment] = ident ~ ("=" ~> expression) ^^ { def assignment: Parser[AstAssignment] = ident ~ ("=" ~> expression) ^^ {
case v ~ exp => AstAssignment(v, exp) case v ~ exp => AstAssignment(v, exp)

View File

@ -14,4 +14,34 @@ class CurryingTest extends LanguageTest {
|""".stripMargin |""".stripMargin
eval(code) shouldEqual 11 eval(code) shouldEqual 11
} }
"Functions" should "allow default arguments to be suspended" in {
val code =
"""
|@{
| fn = { |w, x, y = 10, z = 20| (w + x) + (y + z) };
|
| fn1 = @fn ...;
| fn2 = @fn1 [1, 2] ...;
| fn3 = @fn2 [3] ...;
|
| @fn3
|}
|""".stripMargin
eval(code) shouldEqual 26
}
"Functions" should "allow defaults to be suspended in application chains" in {
val code =
"""
|@{
| fn = { |w, x, y = 10, z = 20| (w + x) + (y + z) };
|
| @(@fn [3, 6] ...) [3]
|}
|""".stripMargin
eval(code) shouldEqual 32
}
} }