Implement suspension of default arguments (#170)

This commit is contained in:
Ara Adkins 2019-09-05 18:26:21 +01:00 committed by GitHub
parent f8dea12e44
commit ed8223c57c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 140 additions and 45 deletions

View File

@ -12,6 +12,7 @@ import org.enso.interpreter.runtime.scope.LocalScope;
* runtime nodes used by the interpreter to guide function evaluation.
*/
public class CallArgFactory implements AstCallArgVisitor<CallArgument> {
private final LocalScope scope;
private final Language language;
private final String scopeName;

View File

@ -4,19 +4,42 @@ import com.oracle.truffle.api.RootCallTarget;
import com.oracle.truffle.api.Truffle;
import com.oracle.truffle.api.frame.FrameSlot;
import com.oracle.truffle.api.nodes.RootNode;
import org.enso.interpreter.*;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.enso.interpreter.AstArgDefinition;
import org.enso.interpreter.AstCallArg;
import org.enso.interpreter.AstCase;
import org.enso.interpreter.AstCaseFunction;
import org.enso.interpreter.AstExpression;
import org.enso.interpreter.AstExpressionVisitor;
import org.enso.interpreter.Language;
import org.enso.interpreter.node.EnsoRootNode;
import org.enso.interpreter.node.ExpressionNode;
import org.enso.interpreter.node.callable.InvokeCallableNodeGen;
import org.enso.interpreter.node.callable.argument.ReadArgumentNode;
import org.enso.interpreter.node.callable.function.CreateFunctionNode;
import org.enso.interpreter.node.callable.function.FunctionBodyNode;
import org.enso.interpreter.node.controlflow.*;
import org.enso.interpreter.node.controlflow.CaseNode;
import org.enso.interpreter.node.controlflow.ConstructorCaseNode;
import org.enso.interpreter.node.controlflow.DefaultFallbackNode;
import org.enso.interpreter.node.controlflow.FallbackNode;
import org.enso.interpreter.node.controlflow.IfZeroNode;
import org.enso.interpreter.node.controlflow.MatchNode;
import org.enso.interpreter.node.expression.builtin.PrintNode;
import org.enso.interpreter.node.expression.constant.ConstructorNode;
import org.enso.interpreter.node.expression.constant.DynamicSymbolNode;
import org.enso.interpreter.node.expression.literal.IntegerLiteralNode;
import org.enso.interpreter.node.expression.operator.*;
import org.enso.interpreter.node.expression.operator.AddOperatorNodeGen;
import org.enso.interpreter.node.expression.operator.DivideOperatorNodeGen;
import org.enso.interpreter.node.expression.operator.ModOperatorNodeGen;
import org.enso.interpreter.node.expression.operator.MultiplyOperatorNodeGen;
import org.enso.interpreter.node.expression.operator.SubtractOperatorNodeGen;
import org.enso.interpreter.node.scope.AssignmentNode;
import org.enso.interpreter.node.scope.AssignmentNodeGen;
import org.enso.interpreter.node.scope.ReadLocalTargetNodeGen;
@ -27,16 +50,12 @@ import org.enso.interpreter.runtime.error.DuplicateArgumentNameException;
import org.enso.interpreter.runtime.scope.ModuleScope;
import org.enso.interpreter.runtime.scope.LocalScope;
import java.util.*;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
* An {@code ExpressionFactory} is responsible for converting the majority of Enso's parsed AST into
* nodes evaluated by the interpreter at runtime.
*/
public class ExpressionFactory implements AstExpressionVisitor<ExpressionNode> {
private final LocalScope scope;
private final Language language;
private final String scopeName;
@ -128,11 +147,21 @@ public class ExpressionFactory implements AstExpressionVisitor<ExpressionNode> {
String operator, AstExpression leftAst, AstExpression rightAst) {
ExpressionNode left = leftAst.visit(this);
ExpressionNode right = rightAst.visit(this);
if (operator.equals("+")) return AddOperatorNodeGen.create(left, right);
if (operator.equals("-")) return SubtractOperatorNodeGen.create(left, right);
if (operator.equals("*")) return MultiplyOperatorNodeGen.create(left, right);
if (operator.equals("/")) return DivideOperatorNodeGen.create(left, right);
if (operator.equals("%")) return ModOperatorNodeGen.create(left, right);
if (operator.equals("+")) {
return AddOperatorNodeGen.create(left, right);
}
if (operator.equals("-")) {
return SubtractOperatorNodeGen.create(left, right);
}
if (operator.equals("*")) {
return MultiplyOperatorNodeGen.create(left, right);
}
if (operator.equals("/")) {
return DivideOperatorNodeGen.create(left, right);
}
if (operator.equals("%")) {
return ModOperatorNodeGen.create(left, right);
}
return null;
}
@ -300,7 +329,7 @@ public class ExpressionFactory implements AstExpressionVisitor<ExpressionNode> {
*/
@Override
public ExpressionNode visitFunctionApplication(
AstExpression function, List<AstCallArg> arguments) {
AstExpression function, List<AstCallArg> arguments, boolean hasDefaultsSuspended) {
CallArgFactory argFactory = new CallArgFactory(scope, language, scopeName, moduleScope);
List<CallArgument> callArgs = new ArrayList<>();
@ -310,7 +339,7 @@ public class ExpressionFactory implements AstExpressionVisitor<ExpressionNode> {
}
return InvokeCallableNodeGen.create(
callArgs.stream().toArray(CallArgument[]::new), function.visit(this));
callArgs.toArray(new CallArgument[0]), hasDefaultsSuspended, function.visit(this));
}
/**

View File

@ -8,6 +8,7 @@ import com.oracle.truffle.api.frame.VirtualFrame;
import com.oracle.truffle.api.nodes.ExplodeLoop;
import com.oracle.truffle.api.nodes.NodeInfo;
import com.oracle.truffle.api.profiles.ConditionProfile;
import java.util.Arrays;
import org.enso.interpreter.Constants;
import org.enso.interpreter.node.ExpressionNode;
import org.enso.interpreter.node.callable.argument.sorter.ArgumentSorterNode;
@ -22,8 +23,6 @@ import org.enso.interpreter.runtime.error.MethodDoesNotExistException;
import org.enso.interpreter.runtime.error.NotInvokableException;
import org.enso.interpreter.runtime.type.TypesGen;
import java.util.Arrays;
/**
* This node is responsible for organising callable calls so that they are ready to be made.
*
@ -33,15 +32,17 @@ import java.util.Arrays;
@NodeInfo(shortName = "@", description = "Executes function")
@NodeChild(value = "callable", type = ExpressionNode.class)
public abstract class InvokeCallableNode extends ExpressionNode {
@Children
private @CompilationFinal(dimensions = 1) ExpressionNode[] argExpressions;
private final boolean canApplyThis;
private final int thisArgumentPosition;
@Children
@CompilationFinal(dimensions = 1)
private ExpressionNode[] argExpressions;
@Child private ArgumentSorterNode argumentSorter;
@Child private MethodResolverNode methodResolverNode;
private final boolean canApplyThis;
private final int thisArgumentPosition;
private final ConditionProfile methodCalledOnNonAtom = ConditionProfile.createCountingProfile();
/**
@ -49,7 +50,7 @@ public abstract class InvokeCallableNode extends ExpressionNode {
*
* @param callArguments information on the arguments being passed to the {@link Function}
*/
public InvokeCallableNode(CallArgument[] callArguments) {
public InvokeCallableNode(CallArgument[] callArguments, boolean hasDefaultsSuspended) {
this.argExpressions =
Arrays.stream(callArguments)
.map(CallArgument::getExpression)
@ -57,9 +58,9 @@ public abstract class InvokeCallableNode extends ExpressionNode {
CallArgumentInfo[] argSchema =
Arrays.stream(callArguments).map(CallArgumentInfo::new).toArray(CallArgumentInfo[]::new);
boolean appliesThis = false;
int idx = 0;
for (; idx < argSchema.length; idx++) {
CallArgumentInfo arg = argSchema[idx];
if (arg.isPositional()
@ -68,15 +69,17 @@ public abstract class InvokeCallableNode extends ExpressionNode {
break;
}
}
this.canApplyThis = appliesThis;
this.thisArgumentPosition = idx;
this.argumentSorter = ArgumentSorterNodeGen.create(argSchema);
this.argumentSorter = ArgumentSorterNodeGen.create(argSchema, hasDefaultsSuspended);
this.methodResolverNode = MethodResolverNodeGen.create();
}
/**
* Marks whether the {@code argumentSorter} child is tailrecursive.
*
* @param isTail whether or not the node is tail-recursive.
*/
@Override

View File

@ -16,18 +16,20 @@ import org.enso.interpreter.runtime.callable.function.Function;
*/
@NodeInfo(shortName = "ArgumentSorter")
public abstract class ArgumentSorterNode extends BaseNode {
private @CompilationFinal(dimensions = 1) CallArgumentInfo[] schema;
private final boolean hasDefaultsSuspended;
/**
* Creates a node that performs the argument organisation for the provided schema.
*
* @param schema information about the call arguments in positional order
*/
public ArgumentSorterNode(CallArgumentInfo[] schema) {
public ArgumentSorterNode(CallArgumentInfo[] schema, boolean hasDefaultsSuspended) {
this.schema = schema;
this.hasDefaultsSuspended = hasDefaultsSuspended;
}
/**
* Generates the argument mapping where it has already been computed and executes the function.
*
@ -49,7 +51,8 @@ public abstract class ArgumentSorterNode extends BaseNode {
public Object invokeCached(
Function function,
Object[] arguments,
@Cached("create(function, getSchema())") CachedArgumentSorterNode mappingNode,
@Cached("create(function, getSchema(), hasDefaultsSuspended())")
CachedArgumentSorterNode mappingNode,
@Cached CallOptimiserNode optimiser) {
Object[] mappedArguments = mappingNode.execute(function, arguments);
if (mappingNode.appliesFully()) {
@ -84,4 +87,13 @@ public abstract class ArgumentSorterNode extends BaseNode {
CallArgumentInfo[] getSchema() {
return schema;
}
/**
* Checks whether the function whose arguments are being sorted has suspended defaults arguments.
*
* @return {@code true} if it has suspended defaults, otherwise {@code false}
*/
boolean hasDefaultsSuspended() {
return this.hasDefaultsSuspended;
}
}

View File

@ -13,6 +13,7 @@ import org.enso.interpreter.runtime.callable.function.Function;
*/
@NodeInfo(shortName = "CachedArgumentSorter")
public class CachedArgumentSorterNode extends BaseNode {
private final Function originalFunction;
private final @CompilationFinal(dimensions = 1) int[] mapping;
private final ArgumentSchema postApplicationSchema;
@ -22,9 +23,12 @@ public class CachedArgumentSorterNode extends BaseNode {
* Creates a node that generates and then caches the argument mapping.
*
* @param function the function to sort arguments for
* @param schema information on the calling arguments
* @param schema information on the calling argument
* @param hasDefaultsSuspended whether or not the function to which these arguments are applied
* has its defaults suspended.
*/
public CachedArgumentSorterNode(Function function, CallArgumentInfo[] schema) {
public CachedArgumentSorterNode(
Function function, CallArgumentInfo[] schema, boolean hasDefaultsSuspended) {
this.originalFunction = function;
CallArgumentInfo.ArgumentMapping mapping =
CallArgumentInfo.ArgumentMapping.generate(function.getSchema(), schema);
@ -33,7 +37,10 @@ public class CachedArgumentSorterNode extends BaseNode {
boolean fullApplication = true;
for (int i = 0; i < postApplicationSchema.getArgumentsCount(); i++) {
if (!(postApplicationSchema.hasDefaultAt(i) || postApplicationSchema.hasPreAppliedAt(i))) {
boolean hasValidDefault = postApplicationSchema.hasDefaultAt(i) && !hasDefaultsSuspended;
boolean hasPreappliedArg = postApplicationSchema.hasPreAppliedAt(i);
if (!(hasValidDefault || hasPreappliedArg)) {
fullApplication = false;
break;
}
@ -48,8 +55,9 @@ public class CachedArgumentSorterNode extends BaseNode {
* @param schema information on the calling arguments
* @return a sorter node for the arguments in {@code schema} being passed to {@code callable}
*/
public static CachedArgumentSorterNode create(Function function, CallArgumentInfo[] schema) {
return new CachedArgumentSorterNode(function, schema);
public static CachedArgumentSorterNode create(
Function function, CallArgumentInfo[] schema, boolean hasDefaultsSuspended) {
return new CachedArgumentSorterNode(function, schema, hasDefaultsSuspended);
}
/**

View File

@ -7,6 +7,12 @@ import com.oracle.truffle.api.TruffleLanguage;
import com.oracle.truffle.api.TruffleLanguage.Env;
import com.oracle.truffle.api.frame.FrameDescriptor;
import com.oracle.truffle.api.source.Source;
import java.io.File;
import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
import org.enso.interpreter.AstGlobalScope;
import org.enso.interpreter.Constants;
import org.enso.interpreter.EnsoParser;
@ -20,12 +26,6 @@ import org.enso.interpreter.util.ScalaConversions;
import org.enso.pkg.Package;
import org.enso.pkg.SourceFile;
import java.io.*;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
/**
* The language context is the internal state of the language that is associated with each thread in
* a running Enso program.

View File

@ -28,7 +28,8 @@ trait AstExpressionVisitor[+T] {
def visitFunctionApplication(
function: AstExpression,
arguments: java.util.List[AstCallArg]
arguments: java.util.List[AstCallArg],
defaultsSuspended: Boolean
): T
def visitIf(
@ -156,10 +157,13 @@ case class AstVariable(name: String) extends AstExpression {
visitor.visitVariable(name)
}
case class AstApply(fun: AstExpression, args: List[AstCallArg])
case class AstApply(
fun: AstExpression,
args: List[AstCallArg],
hasDefaultsSuspended: Boolean)
extends AstExpression {
override def visit[T](visitor: AstExpressionVisitor[T]): T =
visitor.visitFunctionApplication(fun, args.asJava)
visitor.visitFunctionApplication(fun, args.asJava, hasDefaultsSuspended)
}
case class AstFunction(
@ -269,7 +273,7 @@ class EnsoParserInternal extends JavaTokenParsers {
def variable: Parser[AstVariable] = ident ^^ AstVariable
def operand: Parser[AstExpression] =
long | foreign | variable | "(" ~> expression <~ ")" | call
long | foreign | variable | "(" ~> expression <~ ")" | functionCall
def arith: Parser[AstExpression] =
operand ~ ((("+" | "-" | "*" | "/" | "%") ~ operand) ?) ^^ {
@ -280,9 +284,17 @@ class EnsoParserInternal extends JavaTokenParsers {
def expression: Parser[AstExpression] =
ifZero | matchClause | arith | function
def call: Parser[AstApply] = "@" ~> expression ~ (argList ?) ^^ {
case expr ~ args => AstApply(expr, args.getOrElse(Nil))
}
def functionCall: Parser[AstApply] =
"@" ~> expression ~ (argList ?) ~ defaultSuspend ^^ {
case expr ~ args ~ hasDefaultsSuspended =>
AstApply(expr, args.getOrElse(Nil), hasDefaultsSuspended)
}
def defaultSuspend: Parser[Boolean] =
("..." ?) ^^ ({
case Some(_) => true
case None => false
})
def assignment: Parser[AstAssignment] = ident ~ ("=" ~> expression) ^^ {
case v ~ exp => AstAssignment(v, exp)

View File

@ -14,4 +14,34 @@ class CurryingTest extends LanguageTest {
|""".stripMargin
eval(code) shouldEqual 11
}
"Functions" should "allow default arguments to be suspended" in {
val code =
"""
|@{
| fn = { |w, x, y = 10, z = 20| (w + x) + (y + z) };
|
| fn1 = @fn ...;
| fn2 = @fn1 [1, 2] ...;
| fn3 = @fn2 [3] ...;
|
| @fn3
|}
|""".stripMargin
eval(code) shouldEqual 26
}
"Functions" should "allow defaults to be suspended in application chains" in {
val code =
"""
|@{
| fn = { |w, x, y = 10, z = 20| (w + x) + (y + z) };
|
| @(@fn [3, 6] ...) [3]
|}
|""".stripMargin
eval(code) shouldEqual 32
}
}